diff --git a/docs/api.md b/docs/api.md index 1fa80ef..85e286c 100644 --- a/docs/api.md +++ b/docs/api.md @@ -11,6 +11,7 @@ .. autosummary:: :toctree: generated + tl.ancestral_states tl.clades tl.sort ``` diff --git a/pyproject.toml b/pyproject.toml index 56a4a24..e1b15b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "numpy", "pandas", "session-info", + "scipy", ] [project.optional-dependencies] diff --git a/src/pycea/pl/plot_tree.py b/src/pycea/pl/plot_tree.py index a35b3e6..ef9172d 100644 --- a/src/pycea/pl/plot_tree.py +++ b/src/pycea/pl/plot_tree.py @@ -101,7 +101,11 @@ def branches( if mcolors.is_color_like(color): kwargs.update({"color": color}) elif isinstance(color, str): - color_data = get_keyed_edge_data(trees, color) + color_data = get_keyed_edge_data(tdata, color, tree_keys)[color] + print(color_data) + if len(color_data) == 0: + raise ValueError(f"Key {color!r} is not present in any edge.") + color_data.index = color_data.index.map(lambda x: f"{x[0]}-{x[1][0]}-{x[1][1]}") if color_data.dtype.kind in ["i", "f"]: norm = plt.Normalize(vmin=color_data.min(), vmax=color_data.max()) cmap = plt.get_cmap(cmap) @@ -125,7 +129,10 @@ def branches( if isinstance(linewidth, (int, float)): kwargs.update({"linewidth": linewidth}) elif isinstance(linewidth, str): - linewidth_data = get_keyed_edge_data(trees, linewidth) + linewidth_data = get_keyed_edge_data(tdata, linewidth, tree_keys)[linewidth] + if len(linewidth_data) == 0: + raise ValueError(f"Key {linewidth!r} is not present in any edge.") + linewidth_data.index = linewidth_data.index.map(lambda x: f"{x[0]}-{x[1][0]}-{x[1][1]}") if linewidth_data.dtype.kind in ["i", "f"]: linewidths = [linewidth_data[edge] if edge in linewidth_data.index else na_linewidth for edge in edges] kwargs.update({"linewidth": linewidths}) @@ -271,7 +278,10 @@ def nodes( if mcolors.is_color_like(color): kwargs.update({"color": color}) elif isinstance(color, str): - color_data = get_keyed_node_data(trees, color) + color_data = get_keyed_node_data(tdata, color, tree_keys)[color] + if len(color_data) == 0: + raise ValueError(f"Key {color!r} is not present in any node.") + color_data.index = color_data.index.map("-".join) if color_data.dtype.kind in ["i", "f"]: if not vmin: vmin = color_data.min() @@ -298,7 +308,10 @@ def nodes( if isinstance(size, (int, float)): kwargs.update({"s": size}) elif isinstance(size, str): - size_data = get_keyed_node_data(trees, size) + size_data = get_keyed_node_data(tdata, size, tree_keys)[size] + if len(size_data) == 0: + raise ValueError(f"Key {size!r} is not present in any node.") + size_data.index = size_data.index.map("-".join) sizes = [size_data[node] if node in size_data.index else na_size for node in nodes] kwargs.update({"s": sizes}) else: @@ -307,7 +320,10 @@ def nodes( if style in mmarkers.MarkerStyle.markers: kwargs.update({"marker": style}) elif isinstance(style, str): - style_data = get_keyed_node_data(trees, style) + style_data = get_keyed_node_data(tdata, style, tree_keys)[style] + if len(style_data) == 0: + raise ValueError(f"Key {style!r} is not present in any node.") + style_data.index = style_data.index.map("-".join) mmap = _get_categorical_markers(tdata, style, style_data, markers) styles = [mmap[style_data[node]] if node in style_data.index else na_style for node in nodes] for style in set(styles): diff --git a/src/pycea/tl/__init__.py b/src/pycea/tl/__init__.py index a3522c7..7518780 100644 --- a/src/pycea/tl/__init__.py +++ b/src/pycea/tl/__init__.py @@ -1,2 +1,3 @@ from .clades import clades from .sort import sort +from .ancestral_states import ancestral_states \ No newline at end of file diff --git a/src/pycea/tl/ancestral_states.py b/src/pycea/tl/ancestral_states.py new file mode 100755 index 0000000..bb5fc2d --- /dev/null +++ b/src/pycea/tl/ancestral_states.py @@ -0,0 +1,260 @@ +from __future__ import annotations + +from collections.abc import Sequence + +import networkx as nx +import numpy as np +import pandas as pd +import treedata as td + +from pycea.utils import get_keyed_node_data, get_keyed_obs_data, get_root, get_trees + + +def _most_common(arr): + """Finds the most common element in a list.""" + unique_values, counts = np.unique(arr, return_counts=True) + most_common_index = np.argmax(counts) + return unique_values[most_common_index] + + +def _get_node_value(tree, node, key, index): + """Gets the value of a node attribute.""" + if key in tree.nodes[node]: + if index is not None: + return tree.nodes[node][key][index] + else: + return tree.nodes[node][key] + else: + return None + + +def _set_node_value(tree, node, key, value, index): + """Sets the value of a node attribute.""" + if index is not None: + tree.nodes[node][key][index] = value + else: + tree.nodes[node][key] = value + + +def _reconstruct_fitch_hartigan(tree, key, missing="-1", index=None): + """Reconstructs ancestral states using the Fitch-Hartigan algorithm.""" + + # Recursive function to calculate the downpass + def downpass(node): + # Base case: leaf + if tree.out_degree(node) == 0: + value = _get_node_value(tree, node, key, index) + if value == missing: + tree.nodes[node]["value_set"] = missing + else: + tree.nodes[node]["value_set"] = {value} + # Recursive case: internal node + else: + value_sets = [] + for child in tree.successors(node): + downpass(child) + value_set = tree.nodes[child]["value_set"] + if value_set != missing: + value_sets.append(value_set) + if len(value_sets) > 0: + intersection = set.intersection(*value_sets) + if intersection: + tree.nodes[node]["value_set"] = intersection + else: + tree.nodes[node]["value_set"] = set.union(*value_sets) + else: + tree.nodes[node]["value_set"] = missing + + # Recursive function to calculate the uppass + def uppass(node, parent_state=None): + value = _get_node_value(tree, node, key, index) + if value is None: + if parent_state and parent_state in tree.nodes[node]["value_set"]: + value = parent_state + else: + value = min(tree.nodes[node]["value_set"]) + _set_node_value(tree, node, key, value, index) + elif value == missing: + value = parent_state + _set_node_value(tree, node, key, value, index) + for child in tree.successors(node): + uppass(child, value) + + # Run the algorithm + root = get_root(tree) + downpass(root) + uppass(root) + # Clean up + for node in tree.nodes: + if "value_set" in tree.nodes[node]: + del tree.nodes[node]["value_set"] + + +def _reconstruct_sankoff(tree, key, costs, missing="-1", index=None): + """Reconstructs ancestral states using the Sankoff algorithm.""" + + # Recursive function to calculate the Sankoff scores + def sankoff_scores(node): + # Base case: leaf + if tree.out_degree(node) == 0: + leaf_value = _get_node_value(tree, node, key, index) + if leaf_value == missing: + return {value: 0 for value in alphabet} + else: + return {value: 0 if value == leaf_value else float("inf") for value in alphabet} + # Recursive case: internal node + else: + scores = {value: 0 for value in alphabet} + pointers = {value: {} for value in alphabet} + for child in tree.successors(node): + child_scores = sankoff_scores(child) + for value in alphabet: + min_cost, min_value = float("inf"), None + for child_value in alphabet: + cost = child_scores[child_value] + costs.loc[value, child_value] + if cost < min_cost: + min_cost, min_value = cost, child_value + scores[value] += min_cost + pointers[value][child] = min_value + tree.nodes[node]["_pointers"] = pointers + return scores + + # Recursive function to traceback the Sankoff scores + def traceback(node, parent_value=None): + for child in tree.successors(node): + child_value = tree.nodes[node]["_pointers"][parent_value][child] + _set_node_value(tree, child, key, child_value, index) + traceback(child, child_value) + + # Get scores + root = get_root(tree) + alphabet = set(costs.index) + root_scores = sankoff_scores(root) + # Reconstruct ancestral states + root_value = min(root_scores, key=root_scores.get) + _set_node_value(tree, root, key, root_value, index) + traceback(root, root_value) + # Clean up + for node in tree.nodes: + if "_pointers" in tree.nodes[node]: + del tree.nodes[node]["_pointers"] + + +def _reconstruct_mean(tree, key, index): + """Reconstructs ancestral by averaging the values of the children.""" + + def subtree_mean(node): + if tree.out_degree(node) == 0: + return _get_node_value(tree, node, key, index), 1 + else: + values, weights = [], [] + for child in tree.successors(node): + child_value, child_n = subtree_mean(child) + values.append(child_value) + weights.append(child_n) + mean_value = np.average(values, weights=weights) + _set_node_value(tree, node, key, mean_value, index) + return mean_value, sum(weights) + + root = get_root(tree) + subtree_mean(root) + + +def _reconstruct_list(tree, key, sum_func, index): + """Reconstructs ancestral states by concatenating the values of the children.""" + + def subtree_list(node): + if tree.out_degree(node) == 0: + return [_get_node_value(tree, node, key, index)] + else: + values = [] + for child in tree.successors(node): + values.extend(subtree_list(child)) + _set_node_value(tree, node, key, sum_func(values), index) + return values + + root = get_root(tree) + subtree_list(root) + + +def _ancestral_states(tree, key, method="mean", costs=None, missing=None, default=None, index=None): + """Reconstructs ancestral states for a given attribute using a given method""" + if method == "sankoff": + if costs is None: + raise ValueError("Costs matrix must be provided for Sankoff algorithm.") + _reconstruct_sankoff(tree, key, costs, missing, index) + elif method == "fitch_hartigan": + _reconstruct_fitch_hartigan(tree, key, missing, index) + elif method == "mean": + _reconstruct_mean(tree, key, index) + elif method == "mode": + _reconstruct_list(tree, key, _most_common, index) + elif callable(method): + _reconstruct_list(tree, key, method, index) + else: + raise ValueError(f"Method {method} not recognized.") + + +def ancestral_states( + tdata: td.TreeData, + keys: str | Sequence[str], + method: str = "mean", + missing_state: str = "-1", + default_state: str = "0", + costs: pd.DataFrame = None, + tree: str | Sequence[str] | None = None, + copy: bool = False, +) -> None: + """Reconstructs ancestral states for an attribute. + + Parameters + ---------- + tdata + TreeData object. + keys + One or more `obs_keys`, `var_names`, `obsm_keys`, or `obsp_keys` to reconstruct. + method + Method to reconstruct ancestral states. One of "mean", "mode", "fitch_hartigan", "sankoff", + or any function that takes a list of values and returns a single value. + missing_state + The state to consider as missing data. + default_state + The expected state for the root node. + costs + A pd.DataFrame with the costs of changing states (from rows to columns). + tree + The `obst` key or keys of the trees to use. If `None`, all trees are used. + copy + If True, returns a pd.DataFrame with ancestral states. + """ + if isinstance(keys, str): + keys = [keys] + tree_keys = tree + trees = get_trees(tdata, tree_keys) + for _, tree in trees.items(): + data, is_array = get_keyed_obs_data(tdata, keys) + dtypes = {dtype.kind for dtype in data.dtypes} + # Check data type + if dtypes.intersection({"i", "f"}): + if method in ["fitch_hartigan", "sankoff"]: + raise ValueError(f"Method {method} requires categorical data.") + if dtypes.intersection({"O", "S"}): + if method in ["mean"]: + raise ValueError(f"Method {method} requires numerical data.") + # If array add to tree as list + if is_array: + length = data.shape[1] + node_attrs = data.apply(lambda row: list(row), axis=1).to_dict() + for node in tree.nodes: + if node not in node_attrs: + node_attrs[node] = [None] * length + nx.set_node_attributes(tree, node_attrs, keys[0]) + for index in range(length): + _ancestral_states(tree, keys[0], method, costs, missing_state, default_state, index) + # If column add to tree as scalar + else: + for key in keys: + nx.set_node_attributes(tree, data[key].to_dict(), key) + _ancestral_states(tree, key, method, missing_state, default_state) + if copy: + return get_keyed_node_data(tdata, keys, tree_keys) diff --git a/src/pycea/tl/sort.py b/src/pycea/tl/sort.py index ef64d76..6930f65 100755 --- a/src/pycea/tl/sort.py +++ b/src/pycea/tl/sort.py @@ -21,7 +21,7 @@ def _sort_tree(tree, key, reverse=False): def sort(tdata: td.TreeData, key: str, reverse: bool = False, tree: str | Sequence[str] | None = None) -> None: - """Sorts the children of each internal node in a tree based on a given key. + """Reorders branches based on a given key. Parameters ---------- diff --git a/src/pycea/utils.py b/src/pycea/utils.py index f4959bf..7648411 100755 --- a/src/pycea/utils.py +++ b/src/pycea/utils.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Sequence, Mapping +from collections.abc import Mapping, Sequence import networkx as nx import pandas as pd @@ -24,46 +24,46 @@ def get_leaves(tree: nx.DiGraph): return [node for node in nx.dfs_postorder_nodes(tree, get_root(tree)) if tree.out_degree(node) == 0] -def get_keyed_edge_data(tree: nx.DiGraph | Mapping[str, nx.DiGraph], key: str) -> pd.Series: +def get_keyed_edge_data( + tdata: td.TreeData, keys: str | Sequence[str], tree_keys: str | Sequence[str] = None +) -> pd.DataFrame: """Gets edge data for a given key from a tree or set of trees.""" - if isinstance(tree, nx.DiGraph): - trees = {"": tree} - sep = "" - else: - trees = tree - sep = "-" - edge_data = {} + if isinstance(tree_keys, str): + tree_keys = [tree_keys] + if isinstance(keys, str): + keys = [keys] + trees = get_trees(tdata, tree_keys) + data = [] for name, tree in trees.items(): - edge_data.update( - { - (f"{name}{sep}{parent}", f"{name}{sep}{child}"): data.get(key) - for parent, child, data in tree.edges(data=True) - if key in data and data[key] is not None - }) - if len(edge_data) == 0: - raise ValueError(f"Key {key!r} is not present in any edge.") - return pd.Series(edge_data, name=key) + edge_data = {key: nx.get_edge_attributes(tree, key) for key in keys} + edge_data = pd.DataFrame(edge_data) + edge_data["tree"] = name + edge_data["edge"] = edge_data.index + data.append(edge_data) + data = pd.concat(data) + data = data.set_index(["tree", "edge"]) + return data -def get_keyed_node_data(tree: nx.DiGraph | Mapping[str, nx.DiGraph], key: str) -> pd.Series: +def get_keyed_node_data( + tdata: td.TreeData, keys: str | Sequence[str], tree_keys: str | Sequence[str] = None +) -> pd.DataFrame: """Gets node data for a given key a tree or set of trees.""" - if isinstance(tree, nx.DiGraph): - trees = {"": tree} - sep = "" - else: - trees = tree - sep = "-" - node_data = {} + if isinstance(tree_keys, str): + tree_keys = [tree_keys] + if isinstance(keys, str): + keys = [keys] + trees = get_trees(tdata, tree_keys) + data = [] for name, tree in trees.items(): - node_data.update( - { - f"{name}{sep}{node}": data.get(key) - for node, data in tree.nodes(data=True) - if key in data and data[key] is not None - }) - if len(node_data) == 0: - raise ValueError(f"Key {key!r} is not present in any node.") - return pd.Series(node_data, name=key) + tree_data = {key: nx.get_node_attributes(tree, key) for key in keys} + tree_data = pd.DataFrame(tree_data) + tree_data["tree"] = name + data.append(tree_data) + data = pd.concat(data) + data["node"] = data.index + data = data.set_index(["tree", "node"]) + return data def get_keyed_obs_data(tdata: td.TreeData, keys: Sequence[str], layer: str = None) -> pd.DataFrame: @@ -103,7 +103,6 @@ def get_keyed_obs_data(tdata: td.TreeData, keys: Sequence[str], layer: str = Non data.columns = keys elif array_keys: data = pd.DataFrame(data[0], index=tdata.obs_names) - if data.shape[0] == data.shape[1]: data.columns = tdata.obs_names return data, array_keys diff --git a/tests/test_ancestral_states.py b/tests/test_ancestral_states.py new file mode 100755 index 0000000..085a3e4 --- /dev/null +++ b/tests/test_ancestral_states.py @@ -0,0 +1,100 @@ +import networkx as nx +import numpy as np +import pandas as pd +import pytest +import treedata as td + +from pycea.tl.ancestral_states import ancestral_states + + +@pytest.fixture +def tdata(): + tree1 = nx.DiGraph([("root", "B"), ("root", "C"), ("C", "D"), ("C", "E")]) + tree2 = nx.DiGraph([("root", "F")]) + spatial = np.array([[0, 4], [1, 1], [2, 1], [4, 4]]) + characters = np.array([["-1", "0"], ["1", "1"], ["2", "-1"], ["1", "2"]]) + tdata = td.TreeData( + obs=pd.DataFrame( + {"value": [0, 0, 3, 2], "str_value": ["0", "0", "3", "2"], "with_missing": [0, np.nan, 3, 2]}, + index=["B", "D", "E", "F"], + ), + obst={"tree1": tree1, "tree2": tree2}, + obsm={"spatial": spatial, "characters": characters}, + ) + yield tdata + + +def test_ancestral_states(tdata): + # Mean + states = ancestral_states(tdata, "value", method="mean", copy=True) + assert tdata.obst["tree1"].nodes["root"]["value"] == 1 + assert tdata.obst["tree1"].nodes["C"]["value"] == 1.5 + assert states["value"].tolist() == [1, 0, 1.5, 0, 3, 2, 2] + # Median + states = ancestral_states(tdata, "value", method=np.median, copy=True) + assert tdata.obst["tree1"].nodes["root"]["value"] == 0 + # Mode + ancestral_states(tdata, "str_value", method="mode", copy=False, tree="tree1") + for node in tdata.obst["tree1"].nodes: + print(node, tdata.obst["tree1"].nodes[node]) + assert tdata.obst["tree1"].nodes["root"]["str_value"] == "0" + + +def test_ancestral_states_array(tdata): + # Mean + states = ancestral_states(tdata, "spatial", method="mean", copy=True) + assert tdata.obst["tree1"].nodes["root"]["spatial"] == [1.0, 2.0] + assert tdata.obst["tree1"].nodes["C"]["spatial"] == [1.5, 1.0] + assert states.loc[("tree1", "root"), "spatial"] == [1.0, 2.0] + # Median + states = ancestral_states(tdata, "spatial", method=np.median, copy=True) + assert tdata.obst["tree1"].nodes["root"]["spatial"] == [1.0, 1.0] + + +def test_ancestral_states_missing(tdata): + # Mean + states = ancestral_states(tdata, "with_missing", method=np.nanmean, copy=True) + assert tdata.obst["tree1"].nodes["root"]["with_missing"] == 1.5 + assert tdata.obst["tree1"].nodes["C"]["with_missing"] == 3 + assert states.loc[("tree1", "root"), "with_missing"] == 1.5 + + +def test_ancestral_state_fitch(tdata): + states = ancestral_states(tdata, "characters", method="fitch_hartigan", copy=True) + assert tdata.obst["tree1"].nodes["root"]["characters"] == ["1", "0"] + assert tdata.obst["tree2"].nodes["F"]["characters"] == ["1", "2"] + assert states.loc[("tree1", "root"), "characters"] == ["1", "0"] + + +def test_ancestral_states_sankoff(tdata): + costs = pd.DataFrame( + [[0, 1, 2], [10, 0, 10], [10, 10, 0]], + index=["0", "1", "2"], + columns=["0", "1", "2"], + ) + states = ancestral_states(tdata, "characters", method="sankoff", costs=costs, copy=True) + assert tdata.obst["tree1"].nodes["root"]["characters"] == ["0", "0"] + assert tdata.obst["tree2"].nodes["F"]["characters"] == ["1", "2"] + assert states.loc[("tree1", "root"), "characters"] == ["0", "0"] + costs = pd.DataFrame( + [[0, 10, 10], [1, 0, 2], [2, 1, 0]], + index=["0", "1", "2"], + columns=["0", "1", "2"], + ) + states = ancestral_states(tdata, "characters", method="sankoff", costs=costs, copy=True) + assert tdata.obst["tree1"].nodes["root"]["characters"] == ["2", "1"] + + +def test_ancestral_states_invalid(tdata): + with pytest.raises(ValueError): + ancestral_states(tdata, "characters", method="sankoff") + with pytest.raises(ValueError): + ancestral_states(tdata, "characters", method="sankoff", costs=pd.DataFrame()) + with pytest.raises(ValueError): + ancestral_states(tdata, "bad", method="mean") + with pytest.raises(ValueError): + ancestral_states(tdata, "value", method="bad") + with pytest.raises(ValueError): + ancestral_states(tdata, "value", method="fitch_hartigan", copy=False) + with pytest.raises(ValueError): + ancestral_states(tdata, "str_value", method="mean", copy=False) diff --git a/tests/test_utils.py b/tests/test_utils.py index 3f401d5..22f8491 100755 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,7 +2,7 @@ import pandas as pd import pytest -from pycea.utils import get_keyed_edge_data, get_keyed_node_data, get_keyed_obs_data, get_root, get_leaves +from pycea.utils import get_keyed_edge_data, get_keyed_node_data, get_keyed_obs_data, get_leaves, get_root @pytest.fixture @@ -31,20 +31,14 @@ def test_get_leaves(tree): assert get_leaves(nx.DiGraph()) == [] -def test_get_keyed_edge_data(tree): - result = get_keyed_edge_data(tree, "weight") - expected_keys = [("A", "B"), ("B", "D"), ("C", "E")] - expected_values = [5, 3, 4] - assert all(result[key] == value for key, value in zip(expected_keys, expected_values)) - assert ("A", "C") not in result +def test_get_keyed_edge_data(tdata): + data = get_keyed_edge_data(tdata, ["length", "clade"]) + assert data.columns.tolist() == ["length", "clade"] -def test_get_keyed_node_data(tree): - result = get_keyed_node_data(tree, "value") - expected_keys = ["A", "B", "D", "E"] - expected_values = [1, 2, 4, 5] - assert all(result[key] == value for key, value in zip(expected_keys, expected_values)) - assert "C" not in result +def test_get_keyed_node_data(tdata): + data = get_keyed_node_data(tdata, ["x", "y", "clade"]) + assert data.columns.tolist() == ["x", "y", "clade"] def test_get_keyed_obs_data_valid_keys(tdata): @@ -70,3 +64,7 @@ def test_get_keyed_obs_data_invalid_keys(tdata): get_keyed_obs_data(tdata, ["clade", "x", "0", "invalid_key"]) with pytest.raises(ValueError): get_keyed_obs_data(tdata, ["clade", "spatial_distance"]) + + +if __name__ == "__main__": + pytest.main(["-v", __file__])