Skip to content

Commit

Permalink
fix drop duplicates
Browse files Browse the repository at this point in the history
  • Loading branch information
pstjohn committed Sep 22, 2021
1 parent c466a1b commit 80792b8
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
3 changes: 1 addition & 2 deletions alfabet/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ def predict_bdes(smiles, draw=False):
on=['molecule', 'bond_index'], how='left')

# Drop duplicate entries and sort from weakest to strongest
frag_df = frag_df.sort_values('bde_pred').drop_duplicates(
['fragment1', 'fragment2']).reset_index(drop=True)
frag_df = frag_df.sort_values('bde_pred').reset_index(drop=True)

# Draw SVGs
if draw:
Expand Down
12 changes: 11 additions & 1 deletion tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import numpy as np

from alfabet import model


def test_predict():
results = model.predict(['CC', 'NCCO', 'CF', 'B'], verbose=False)

Expand All @@ -16,4 +18,12 @@ def test_predict():

np.testing.assert_allclose(
results[results.molecule == 'NCCO'].bde_pred,
[90.0, 82.1, 98.2, 99.3, 92.1, 92.5, 105.2], atol=1., rtol=.05)
[90.0, 82.1, 98.2, 99.3, 92.1, 92.5, 105.2], atol=1., rtol=.05)


def test_duplicates():
results = model.predict(['c1ccccc1'], verbose=False, drop_duplicates=True)
assert len(results) == 1

results = model.predict(['c1ccccc1'], verbose=False, drop_duplicates=False)
assert len(results) == 6

0 comments on commit 80792b8

Please sign in to comment.