Skip to content

Commit

Permalink
Merge pull request #13 from NREL/fix-noncanon-smiles
Browse files Browse the repository at this point in the history
handle non-canonical smiles inputs
  • Loading branch information
pstjohn authored Aug 24, 2022
2 parents ce2e890 + a951545 commit 688d079
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 29 deletions.
21 changes: 14 additions & 7 deletions alfabet/fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,23 +37,28 @@ def molH(self) -> Type[rdkit.Chem.Mol]:
@property
def smiles(self) -> str:
if (self._smiles is None) or not self._is_canon:
self._smiles = rdkit.Chem.MolToSmiles(self._mol)
self._smiles = rdkit.Chem.MolToSmiles(self.mol)
return self._smiles


def get_fragments(smiles: str, drop_duplicates: bool = False) -> pd.DataFrame:
df = pd.DataFrame(fragment_iterator(smiles))
def get_fragments(
input_molecule: Molecule, drop_duplicates: bool = False
) -> pd.DataFrame:
df = pd.DataFrame(fragment_iterator(input_molecule))
if drop_duplicates:
df = df.drop_duplicates(["fragment1", "fragment2"]).reset_index(drop=True)
return df


def fragment_iterator(smiles: str, skip_warnings: bool = False) -> Iterator[Dict]:
def fragment_iterator(
input_molecule: str, skip_warnings: bool = False
) -> Iterator[Dict]:

input_molecule = Molecule(smiles=smiles)
mol_stereo = count_stereocenters(input_molecule)
if (mol_stereo["atom_unassigned"] != 0) or (mol_stereo["bond_unassigned"] != 0):
logging.warning(f"Molecule {smiles} has undefined stereochemistry")
logging.warning(
f"Molecule {input_molecule.smiles} has undefined stereochemistry"
)
if skip_warnings:
return

Expand Down Expand Up @@ -110,7 +115,9 @@ def fragment_iterator(smiles: str, skip_warnings: bool = False) -> Iterator[Dict

except ValueError:
logging.error(
"Fragmentation error with {}, bond {}".format(smiles, bond.GetIdx())
"Fragmentation error with {}, bond {}".format(
input_molecule.smiles, bond.GetIdx()
)
)
continue

Expand Down
26 changes: 15 additions & 11 deletions alfabet/model.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
from typing import List

import pandas as pd
import rdkit.Chem
from nfp.frameworks import tf
from tqdm import tqdm

from alfabet.fragment import get_fragments
from alfabet.fragment import Molecule, get_fragments
from alfabet.prediction import bde_dft, model, validate_inputs
from alfabet.preprocessor import get_features, preprocessor


def get_max_bonds(smiles_list):
def num_bonds(smiles):
mol = rdkit.Chem.MolFromSmiles(smiles)
molH = rdkit.Chem.AddHs(mol)
def get_max_bonds(molecule_list: List[Molecule]):
def num_bonds(molecule):
molH = rdkit.Chem.AddHs(molecule.mol)
return molH.GetNumBonds()

return max((num_bonds(smiles) for smiles in smiles_list))
return max((num_bonds(molecule) for molecule in molecule_list))


def predict(smiles_list, drop_duplicates=True, batch_size=1, verbose=False):
Expand Down Expand Up @@ -47,18 +48,21 @@ def predict(smiles_list, drop_duplicates=True, batch_size=1, verbose=False):
domain of validity
"""

molecule_list = [Molecule(smiles=smiles) for smiles in smiles_list]
smiles_list = [mol.smiles for mol in molecule_list]

pred_df = pd.concat(
(
get_fragments(smiles, drop_duplicates=drop_duplicates)
for smiles in tqdm(smiles_list, disable=not verbose)
get_fragments(mol, drop_duplicates=drop_duplicates)
for mol in tqdm(molecule_list, disable=not verbose)
)
)

max_bonds = get_max_bonds(smiles_list)
max_bonds = get_max_bonds(molecule_list)
input_dataset = tf.data.Dataset.from_generator(
lambda: (
get_features(smiles, max_num_edges=2 * max_bonds)
for smiles in tqdm(smiles_list, disable=not verbose)
get_features(mol.smiles, max_num_edges=2 * max_bonds)
for mol in tqdm(molecule_list, disable=not verbose)
),
output_signature=preprocessor.output_signature,
).cache()
Expand Down
29 changes: 18 additions & 11 deletions tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,36 @@
import numpy as np

from alfabet import model


def test_predict():
results = model.predict(['CC', 'NCCO', 'CF', 'B'])
results = model.predict(["CC", "NCCO", "CF", "B"])

assert not results[results.molecule == 'B'].is_valid.any()
assert results[results.molecule != 'B'].is_valid.all()
assert not results[results.molecule == "B"].is_valid.any()
assert results[results.molecule != "B"].is_valid.all()

# Should be less than 1 kcal/mol on this easy set
assert (results.bde_pred - results.bde).abs().mean() < 1.
assert (results.bde_pred - results.bde).abs().mean() < 1.0

np.testing.assert_allclose(
results[results.molecule == 'CC'].bde_pred,
[90.7, 99.8], atol=1., rtol=.05)
results[results.molecule == "CC"].bde_pred, [90.7, 99.8], atol=1.0, rtol=0.05
)

np.testing.assert_allclose(
results[results.molecule == 'NCCO'].bde_pred,
[90.0, 82.1, 98.2, 99.3, 92.1, 92.5, 105.2], atol=1., rtol=.05)
results[results.molecule == "NCCO"].bde_pred,
[90.0, 82.1, 98.2, 99.3, 92.1, 92.5, 105.2],
atol=1.0,
rtol=0.05,
)


def test_duplicates():
results = model.predict(['c1ccccc1'], drop_duplicates=True)
results = model.predict(["c1ccccc1"], drop_duplicates=True)
assert len(results) == 1

results = model.predict(['c1ccccc1'], drop_duplicates=False)
results = model.predict(["c1ccccc1"], drop_duplicates=False)
assert len(results) == 6


def test_non_canonical_smiles():
    """Check that ``model.predict`` accepts a non-canonical SMILES input.

    The input below is deliberately not in canonical form; the prediction
    for it is expected to contain 24 rows.
    """
    # Raw string: the backslash is a SMILES bond-direction marker, and "\C"
    # is not a valid Python escape sequence (it triggers a
    # DeprecationWarning/SyntaxWarning in a normal string literal).
    smiles = r"CC(=O)OCC1=C\CC/C(C)=C/CC[C@@]2(C)CC[C@@](C(C)C)(/C=C/1)O2"
    assert len(model.predict([smiles])) == 24

0 comments on commit 688d079

Please sign in to comment.