Skip to content

Commit

Permalink
hotfix custom interaction class error msg
Browse files Browse the repository at this point in the history
  • Loading branch information
cbouy committed Jun 11, 2021
1 parent 2153f71 commit 6e5162b
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 39 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Removed
### Fixed

## [0.3.3] - 2021-06-11
### Changed
- Custom interactions must return three values: a boolean for the interaction,
and the indices of residue atoms responsible for the interaction
### Fixed
- Custom interactions that only returned a single value instead of three would
raise an uninformative error message


## [0.3.2] - 2021-06-11
### Added
- LigNetwork: an interaction diagram with atomistic details for the ligand and
Expand Down
23 changes: 2 additions & 21 deletions docs/notebooks/how-to.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,9 @@
"\n",
"This method takes exactly two positional arguments (and as many named arguments as you need): a ligand Residue or Molecule and a protein Residue or Molecule (in this order).\n",
"\n",
"* **Return value(s) for the `detect` method**\n",
"* **Return values for the `detect` method**\n",
"\n",
"There are two possibilities here, depending on whether or not you want to access the indices of atoms responsible for the interaction. If you don't need this information, just return `True` if the interaction is detected, `False` otherwise. If you need to access atomic indices, you must return the following items in this order: \n",
"You must return the following items in this order: \n",
"\n",
" * `True` or `False` for the detection of the interaction\n",
" * The index of the ligand atom, or None if not detected\n",
Expand All @@ -239,25 +239,6 @@
"source": [
"from scipy.spatial import distance_matrix\n",
"\n",
"# without atom indices\n",
"class CloseContact(plf.interactions.Interaction):\n",
" def detect(self, res1, res2, threshold=2.0):\n",
" dist_matrix = distance_matrix(res1.xyz, res2.xyz)\n",
" if (dist_matrix <= threshold).any():\n",
" return True\n",
" return False\n",
"\n",
"fp = plf.Fingerprint()\n",
"fp.closecontact(lmol, pmol[\"ASP129.A\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# with atom indices\n",
"class CloseContact(plf.interactions.Interaction):\n",
" def detect(self, res1, res2, threshold=2.0):\n",
" dist_matrix = distance_matrix(res1.xyz, res2.xyz)\n",
Expand Down
19 changes: 14 additions & 5 deletions prolif/fingerprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@

def _return_first_element(f):
"""Modifies the return signature of a function by forcing it to return
only the first element if multiple values were returned.
only the first element when multiple values are returned
Raises
------
TypeError
If the function doesn't return three values
Notes
-----
Expand All @@ -57,10 +62,13 @@ def _return_first_element(f):
def wrapper(*args, **kwargs):
results = f(*args, **kwargs)
try:
value, *rest = results
except TypeError:
value = results
return value
bool_, lig_idx, prot_idx = results
except (TypeError, ValueError):
raise TypeError(
"Incorrect function signature: the interaction class must "
"return 3 values (boolean, int, int)"
) from None
return bool_
return wrapper


Expand Down Expand Up @@ -218,6 +226,7 @@ def bitvector_atoms(self, res1, res2):
A list containing indices for the protein atoms responsible for
each interaction
.. versionchanged:: 0.3.2
Atom indices are returned as two separate lists instead of a single
list of tuples
Expand Down
35 changes: 22 additions & 13 deletions tests/test_fingerprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,32 @@

class Dummy(Interaction):
def detect(self, res1, res2):
return 1, 2, 3
return True, 4, 2


def func_return_single_val():
return 0
def return_value(*args):
return args if len(args) > 1 else args[0]


def test_wrapper_return():
foo = Dummy().detect
bar = _return_first_element(foo)
assert foo("foo", "bar") == (1, 2, 3)
assert bar("foo", "bar") == 1
assert bar.__wrapped__("foo", "bar") == (1, 2, 3)
baz = _return_first_element(func_return_single_val)
assert baz() == 0
assert baz.__wrapped__() == 0

detect = Dummy().detect
mod = _return_first_element(detect)
assert detect("foo", "bar") == (True, 4, 2)
assert mod("foo", "bar") is True
assert mod.__wrapped__("foo", "bar") == (True, 4, 2)


@pytest.mark.parametrize("returned", [
True,
(True,),
(True, 4)
])
def test_wrapper_incorrect_return(returned):
mod = _return_first_element(return_value)
assert mod.__wrapped__(returned) == returned
with pytest.raises(TypeError,
match="Incorrect function signature"):
mod(returned)

class TestFingerprint:
@pytest.fixture
Expand Down Expand Up @@ -60,7 +69,7 @@ def test_n_interactions(self, fp):

def test_wrapped(self, fp):
assert fp.dummy("foo", "bar") == 1
assert fp.dummy.__wrapped__("foo", "bar") == (1, 2, 3)
assert fp.dummy.__wrapped__("foo", "bar") == (True, 4, 2)

def test_bitvector(self, fp):
bv = fp.bitvector(ligand_mol, protein_mol["ASP129.A"])
Expand Down

0 comments on commit 6e5162b

Please sign in to comment.