From 6e5162bf10a2f69b67d2d5b9e01dc8512acf4b9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Bouysset?= Date: Fri, 11 Jun 2021 15:13:49 +0200 Subject: [PATCH] hotfix custom interaction class error msg --- CHANGELOG.md | 9 +++++++++ docs/notebooks/how-to.ipynb | 23 ++--------------------- prolif/fingerprint.py | 19 ++++++++++++++----- tests/test_fingerprint.py | 35 ++++++++++++++++++++++------------- 4 files changed, 47 insertions(+), 39 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 984fb96..346952e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Removed ### Fixed +## [0.3.3] - 2021-06-11 +### Changed +- Custom interactions must return three values: a boolean for the interaction, + and the indices of residue atoms responsible for the interaction +### Fixed +- Custom interactions that only returned a single value instead of three would + raise an uninformative error message + + ## [0.3.2] - 2021-06-11 ### Added - LigNetwork: an interaction diagram with atomistic details for the ligand and diff --git a/docs/notebooks/how-to.ipynb b/docs/notebooks/how-to.ipynb index 9625938..ffed019 100644 --- a/docs/notebooks/how-to.ipynb +++ b/docs/notebooks/how-to.ipynb @@ -222,9 +222,9 @@ "\n", "This method takes exactly two positional arguments (and as many named arguments as you need): a ligand Residue or Molecule and a protein Residue or Molecule (in this order).\n", "\n", - "* **Return value(s) for the `detect` method**\n", + "* **Return values for the `detect` method**\n", "\n", - "There are two possibilities here, depending on whether or not you want to access the indices of atoms responsible for the interaction. If you don't need this information, just return `True` if the interaction is detected, `False` otherwise. If you need to access atomic indices, you must return the following items in this order: \n", + "You must return the following items in this order: \n", "\n", " * `True` or `False` for the detection of the interaction\n", " * The index of the ligand atom, or None if not detected\n", @@ -239,25 +239,6 @@ "source": [ "from scipy.spatial import distance_matrix\n", "\n", - "# without atom indices\n", - "class CloseContact(plf.interactions.Interaction):\n", - " def detect(self, res1, res2, threshold=2.0):\n", - " dist_matrix = distance_matrix(res1.xyz, res2.xyz)\n", - " if (dist_matrix <= threshold).any():\n", - " return True\n", - " return False\n", - "\n", - "fp = plf.Fingerprint()\n", - "fp.closecontact(lmol, pmol[\"ASP129.A\"])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# with atom indices\n", "class CloseContact(plf.interactions.Interaction):\n", " def detect(self, res1, res2, threshold=2.0):\n", " dist_matrix = distance_matrix(res1.xyz, res2.xyz)\n", diff --git a/prolif/fingerprint.py b/prolif/fingerprint.py index b946fd4..e8727c0 100644 --- a/prolif/fingerprint.py +++ b/prolif/fingerprint.py @@ -30,7 +30,12 @@ def _return_first_element(f): """Modifies the return signature of a function by forcing it to return - only the first element if multiple values were returned. + only the first element when multiple values are returned + + Raises + ------ + TypeError + If the function doesn't return three values Notes ----- @@ -57,10 +62,13 @@ def _return_first_element(f): def wrapper(*args, **kwargs): results = f(*args, **kwargs) try: - value, *rest = results - except TypeError: - value = results - return value + bool_, lig_idx, prot_idx = results + except (TypeError, ValueError): + raise TypeError( + "Incorrect function signature: the interaction class must " + "return 3 values (boolean, int, int)" + ) from None + return bool_ return wrapper @@ -218,6 +226,7 @@ def bitvector_atoms(self, res1, res2): A list containing indices for the protein atoms responsible for each interaction + .. versionchanged:: 0.3.2 Atom indices are returned as two separate lists instead of a single list of tuples diff --git a/tests/test_fingerprint.py b/tests/test_fingerprint.py index d00ddd5..7717744 100644 --- a/tests/test_fingerprint.py +++ b/tests/test_fingerprint.py @@ -13,23 +13,32 @@ class Dummy(Interaction): def detect(self, res1, res2): - return 1, 2, 3 + return True, 4, 2 -def func_return_single_val(): - return 0 +def return_value(*args): + return args if len(args) > 1 else args[0] def test_wrapper_return(): - foo = Dummy().detect - bar = _return_first_element(foo) - assert foo("foo", "bar") == (1, 2, 3) - assert bar("foo", "bar") == 1 - assert bar.__wrapped__("foo", "bar") == (1, 2, 3) - baz = _return_first_element(func_return_single_val) - assert baz() == 0 - assert baz.__wrapped__() == 0 - + detect = Dummy().detect + mod = _return_first_element(detect) + assert detect("foo", "bar") == (True, 4, 2) + assert mod("foo", "bar") is True + assert mod.__wrapped__("foo", "bar") == (True, 4, 2) + + +@pytest.mark.parametrize("returned", [ + True, + (True,), + (True, 4) +]) +def test_wrapper_incorrect_return(returned): + mod = _return_first_element(return_value) + assert mod.__wrapped__(returned) == returned + with pytest.raises(TypeError, + match="Incorrect function signature"): + mod(returned) class TestFingerprint: @pytest.fixture @@ -60,7 +69,7 @@ def test_n_interactions(self, fp): def test_wrapped(self, fp): assert fp.dummy("foo", "bar") == 1 - assert fp.dummy.__wrapped__("foo", "bar") == (1, 2, 3) + assert fp.dummy.__wrapped__("foo", "bar") == (True, 4, 2) def test_bitvector(self, fp): bv = fp.bitvector(ligand_mol, protein_mol["ASP129.A"])