Skip to content

Commit

Permalink
Merge pull request #285 from seoklab/jnooree/issue-283
Browse files Browse the repository at this point in the history
feat(python/fmt): create python writer interfaces
  • Loading branch information
jnooree authored Apr 11, 2024
2 parents dea7728 + 1f42945 commit 48caa15
Show file tree
Hide file tree
Showing 8 changed files with 159 additions and 1 deletion.
15 changes: 15 additions & 0 deletions python/docs/nuri.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,25 @@ Submodules
Top-level Functions
-------------------

Readers
-------

.. autofunction:: nuri.readfile

.. autofunction:: nuri.readstring

Writers
-------

These functions all release the GIL and are thread-safe. Thread-based
parallelism is recommended for writing multiple molecules in parallel.

.. autofunction:: nuri.to_smiles

.. autofunction:: nuri.to_mol2

.. autofunction:: nuri.to_sdf

--------------------
Top-level Attributes
--------------------
Expand Down
1 change: 1 addition & 0 deletions python/include/nuri/python/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ inline int py_check_index(int size, int idx, const char *onerror) {
}

// pybind11 call policy keep_alive<0, 1>: index 0 denotes the return value and
// index 1 the first argument, so the first argument is kept alive for at least
// as long as the returned object. NOTE(review): the name suggests it is meant
// for accessors that return a view into their parent object - confirm at the
// use sites.
constexpr inline static py::keep_alive<0, 1> kReturnsSubobject {};
// pybind11 call guard that holds a py::gil_scoped_release for the duration of
// the bound C++ call, i.e. the wrapped function body runs with the GIL
// released.
constexpr inline static py::call_guard<py::gil_scoped_release> kThreadSafe {};

template <class T, class Size, class Getter, class Iter>
py::class_<T> &add_sequence_interface(py::class_<T> &cls, Size size,
Expand Down
5 changes: 4 additions & 1 deletion python/src/nuri/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
__all__ = [
"readfile",
"readstring",
"to_smiles",
"to_mol2",
"to_sdf",
"periodic_table",
"__version__",
]
Expand All @@ -22,4 +25,4 @@

from . import _log_adapter
from .core import periodic_table
from .fmt import readfile, readstring
from .fmt import readfile, readstring, to_smiles, to_mol2, to_sdf
80 changes: 80 additions & 0 deletions python/src/nuri/fmt/fmt_module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <fstream>
#include <istream>
#include <memory>
#include <optional>
#include <sstream>
#include <string>
#include <string_view>
Expand All @@ -24,6 +25,9 @@

#include "nuri/core/molecule.h"
#include "nuri/fmt/base.h"
#include "nuri/fmt/mol2.h"
#include "nuri/fmt/sdf.h"
#include "nuri/fmt/smiles.h"
#include "nuri/python/core/core_module.h"
#include "nuri/python/exception.h"
#include "nuri/python/utils.h"
Expand Down Expand Up @@ -101,6 +105,21 @@ class PyMoleculeReader {
bool skip_on_error_;
};

// Serializes `mol` into a freshly created string by invoking
// `writer(out, mol, args...)`.
//
// `fmt` is a human-readable format name that is only used to build the error
// message. Returns the filled string when the writer reports success; throws
// py::value_error otherwise.
template <class F, class... Args>
std::string try_write(const Molecule &mol, std::string_view fmt, F writer,
                      Args &&...args) {
  std::string out;

  // Happy path first: a truthy result means `out` holds the serialized text.
  if (writer(out, mol, std::forward<Args>(args)...))
    return out;

  throw py::value_error(absl::StrCat("Failed to convert molecule to ", fmt));
}

// Translates the optional conformer index received from Python into the plain
// int handed to the writers.
//
// When an index is supplied it is passed through check_conf() and that result
// is returned; when it is absent, -1 is returned.
// NOTE(review): -1 presumably tells the writers to emit every conformer -
// confirm against the write_mol2/write_sdf declarations.
int writer_check_conf(const Molecule &mol, std::optional<int> oconf) {
  if (oconf.has_value())
    return check_conf(mol, *oconf);

  return -1;
}

namespace fs = std::filesystem;

NURI_PYTHON_MODULE(m) {
Expand Down Expand Up @@ -167,6 +186,67 @@ The returned object is an iterable of molecules.
>>> for mol in nuri.readstring("smi", "C"):
... print(mol[0].atomic_number)
6
)doc");

// Python-facing writer functions. Each one converts a single molecule into an
// in-memory string through try_write(), which throws py::value_error when the
// underlying writer reports failure. All three are registered with
// kThreadSafe (a py::call_guard<py::gil_scoped_release> declared in
// nuri/python/utils.h), so the lambda bodies run with the GIL released.
m.def(
"to_smiles",
[](const PyMol &mol) {
// NOTE(review): the trailing `false` is forwarded to write_smiles as an
// extra argument; its meaning is not visible here - confirm against
// nuri/fmt/smiles.h.
return try_write(*mol, "smiles", write_smiles, false);
},
py::arg("mol"), kThreadSafe, R"doc(
Convert a molecule to SMILES string.
:param mol: The molecule to convert.
:raises ValueError: If the conversion fails.
)doc")
// to_mol2: `conf` defaults to None on the Python side.
.def(
"to_mol2",
[](const PyMol &mol, std::optional<int> oconf) {
// None becomes -1; any other value goes through check_conf() first.
int conf = writer_check_conf(*mol, oconf);
return try_write(*mol, "Mol2", write_mol2, conf);
},
py::arg("mol"), py::arg("conf") = py::none(), kThreadSafe, R"doc(
Convert a molecule to Mol2 string.
:param mol: The molecule to convert.
:param conf: The conformation to convert. If not specified, writes all
conformations. Ignored if the molecule has no conformations.
:raises IndexError: If the molecule has any conformations and `conf` is out of
range.
:raises ValueError: If the conversion fails.
)doc")
// to_sdf: same conformer handling as to_mol2, plus an optional format version.
.def(
"to_sdf",
[](const PyMol &mol, std::optional<int> oconf,
std::optional<int> oversion) {
// The conformer index is resolved before the version is inspected.
int conf = writer_check_conf(*mol, oconf);

// Map the optional integer onto SDFVersion: None keeps kAutomatic, 2000 and
// 3000 select the matching enumerator, anything else is rejected.
SDFVersion version = SDFVersion::kAutomatic;
if (oversion) {
int user_version = *oversion;
if (user_version == 2000)
version = SDFVersion::kV2000;
else if (user_version == 3000)
version = SDFVersion::kV3000;
else
throw py::value_error(
absl::StrCat("Invalid SDF version: ", user_version));
}

return try_write(*mol, "SDF", write_sdf, conf, version);
},
py::arg("mol"), py::arg("conf") = py::none(),
py::arg("version") = py::none(), kThreadSafe, R"doc(
Convert a molecule to SDF string.
:param mol: The molecule to convert.
:param conf: The conformation to convert. If not specified, writes all
conformations. Ignored if the molecule has no conformations.
:param version: The SDF version to write. If not specified, the version is
automatically determined. Only 2000 and 3000 are supported.
:raises IndexError: If the molecule has any conformations and `conf` is out of
range.
:raises ValueError: If the conversion fails, or if the version is invalid.
)doc");
}
} // namespace
Expand Down
File renamed without changes.
22 changes: 22 additions & 0 deletions python/test/fmt/mol2_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from pathlib import Path
from typing import List

import numpy as np
import pytest

import nuri
from nuri.core import Molecule, Hyb

Expand Down Expand Up @@ -108,3 +111,22 @@ def test_mol2_file(tmp_path: Path):
def test_mol2_str():
    """Parse the Mol2 fixture, then check it survives a write/read cycle."""
    parsed = list(nuri.readstring("mol2", mol2_data))
    _verify_mols(parsed)

    # Serialize each molecule and concatenate the records into one document.
    rewritten = "".join(nuri.to_mol2(mol) for mol in parsed)
    reparsed = list(nuri.readstring("mol2", rewritten))
    _verify_mols(reparsed)


def test_mol2_options(mol3d: Molecule):
    """Exercise the ``conf`` argument of ``nuri.to_mol2``."""
    # Without ``conf`` the output is expected to parse back as two molecules.
    all_confs = list(nuri.readstring("mol2", nuri.to_mol2(mol3d)))
    assert len(all_confs) == 2

    # Selecting conformer 1 must yield a single molecule whose only conformer
    # matches the source coordinates within the tolerance.
    selected = list(nuri.readstring("mol2", nuri.to_mol2(mol3d, conf=1)))
    assert len(selected) == 1
    assert np.allclose(mol3d.get_conf(1), selected[0].get_conf(0), atol=1e-3)

    # Conformer index 2 is expected to be rejected.
    with pytest.raises(IndexError):
        nuri.to_mol2(mol3d, conf=2)
33 changes: 33 additions & 0 deletions python/test/fmt/sdf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from pathlib import Path
from typing import List

import numpy as np
import pytest

import nuri
from nuri.core import Molecule, Hyb

Expand Down Expand Up @@ -102,3 +105,33 @@ def test_sdf_file(tmp_path: Path):
def test_sdf_str():
    """Parse the SDF fixture, then check it survives a write/read cycle."""
    parsed = list(nuri.readstring("sdf", sdf_data))
    _verify_mols(parsed)

    # Serialize each molecule and concatenate the records into one document.
    rewritten = "".join(nuri.to_sdf(mol) for mol in parsed)
    reparsed = list(nuri.readstring("sdf", rewritten))
    _verify_mols(reparsed)


def test_sdf_options(mol3d: Molecule):
    """Exercise the ``conf`` and ``version`` arguments of ``nuri.to_sdf``."""
    # Default call: the output is expected to parse back as two molecules.
    default_out = list(nuri.readstring("sdf", nuri.to_sdf(mol3d)))
    assert len(default_out) == 2

    # Both explicitly supported versions must give the same molecule count.
    for sdf_version in (2000, 3000):
        text = nuri.to_sdf(mol3d, version=sdf_version)
        assert len(list(nuri.readstring("sdf", text))) == 2

    # Selecting conformer 1 must yield a single molecule whose only conformer
    # matches the source coordinates within the tolerance.
    selected = list(nuri.readstring("sdf", nuri.to_sdf(mol3d, conf=1)))
    assert len(selected) == 1
    assert np.allclose(mol3d.get_conf(1), selected[0].get_conf(0), atol=1e-3)

    # Conformer index 2 is expected to be rejected.
    with pytest.raises(IndexError):
        nuri.to_sdf(mol3d, conf=2)

    # Unsupported version numbers are reported as a ValueError.
    with pytest.raises(ValueError, match="Invalid SDF version"):
        nuri.to_sdf(mol3d, version=9999)
4 changes: 4 additions & 0 deletions python/test/fmt/smi_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,7 @@ def test_smiles_file(tmp_path: Path):
def test_smiles_str():
    """Parse the SMILES fixture, then check it survives a write/read cycle."""
    parsed = list(nuri.readstring("smi", smi_data))
    _verify_mols(parsed)

    # Serialize each molecule and concatenate the results into one document.
    rewritten = "".join(nuri.to_smiles(mol) for mol in parsed)
    reparsed = list(nuri.readstring("smi", rewritten))
    _verify_mols(reparsed)

0 comments on commit 48caa15

Please sign in to comment.