From bae807b0ddfac678fad79d4ac0c3ad8a59328d63 Mon Sep 17 00:00:00 2001 From: colganwi Date: Fri, 24 May 2024 19:08:25 -0400 Subject: [PATCH 1/4] ancestral states --- docs/api.md | 1 + pyproject.toml | 1 + src/pycea/tl/__init__.py | 1 + src/pycea/tl/ancestral_states.py | 90 ++++++++++++++++++++++++++++++++ src/pycea/tl/sort.py | 2 +- tests/test_ancestral_states.py | 31 +++++++++++ 6 files changed, 125 insertions(+), 1 deletion(-) create mode 100755 src/pycea/tl/ancestral_states.py create mode 100755 tests/test_ancestral_states.py diff --git a/docs/api.md b/docs/api.md index 1fa80ef..85e286c 100644 --- a/docs/api.md +++ b/docs/api.md @@ -11,6 +11,7 @@ .. autosummary:: :toctree: generated + tl.ancestral_states tl.clades tl.sort ``` diff --git a/pyproject.toml b/pyproject.toml index 56a4a24..e1b15b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "numpy", "pandas", "session-info", + "scipy", ] [project.optional-dependencies] diff --git a/src/pycea/tl/__init__.py b/src/pycea/tl/__init__.py index a3522c7..7518780 100644 --- a/src/pycea/tl/__init__.py +++ b/src/pycea/tl/__init__.py @@ -1,2 +1,3 @@ from .clades import clades from .sort import sort +from .ancestral_states import ancestral_states \ No newline at end of file diff --git a/src/pycea/tl/ancestral_states.py b/src/pycea/tl/ancestral_states.py new file mode 100755 index 0000000..38f0ddc --- /dev/null +++ b/src/pycea/tl/ancestral_states.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +from collections.abc import Sequence + +import networkx as nx +import numpy as np +import pandas as pd +import treedata as td + +from pycea.utils import get_keyed_node_data, get_keyed_obs_data, get_trees + + +def _most_common(arr): + """Finds the most common element in a list.""" + unique_values, counts = np.unique(arr, return_counts=True) + most_common_index = np.argmax(counts) + return unique_values[most_common_index] + + +def _ancestral_states(tree, key, method="mean"): + """Finds the ancestral state of a node in a tree.""" + # Get summation function + if method == "mean": + sum_func = np.mean + elif method == "median": + sum_func = np.median + elif method == "mode": + sum_func = _most_common + else: + raise ValueError(f"Method {method} not recognized.") + # Get aggregation function + if method in ["mean", "median", "mode"]: + agg_func = np.concatenate + # infer ancestral states + for node in nx.dfs_postorder_nodes(tree): + if tree.out_degree(node) == 0: + tree.nodes[node]["_message"] = np.array([tree.nodes[node][key]]) + else: + subtree_values = agg_func([tree.nodes[child]["_message"] for child in tree.successors(node)]) + tree.nodes[node]["_message"] = subtree_values + tree.nodes[node][key] = sum_func(subtree_values) + # remove messages + for node in tree.nodes: + del tree.nodes[node]["_message"] + + +def ancestral_states( + tdata: td.TreeData, + keys: str | Sequence[str], + method: str = "mean", + tree: str | Sequence[str] | None = None, + copy: bool = False, +) -> None: + """Reconstructs ancestral states for an attribute. + + Parameters + ---------- + tdata + TreeData object. + keys + One or more `obs_keys`, `var_names`, `obsm_keys`, or `obsp_keys` to reconstruct. + method + Method to reconstruct ancestral states. One of "mean", "median", or "mode". + tree + The `obst` key or keys of the trees to use. If `None`, all trees are used. + copy + If True, returns a pd.DataFrame with ancestral states. + """ + if isinstance(keys, str): + keys = [keys] + tree_keys = tree + trees = get_trees(tdata, tree_keys) + for _, tree in trees.items(): + data, _ = get_keyed_obs_data(tdata, keys) + for key in keys: + nx.set_node_attributes(tree, data[key].to_dict(), key) + _ancestral_states(tree, key, method) + if copy: + states = [] + for name, tree in trees.items(): + tree_states = [] + for key in keys: + data = get_keyed_node_data(tree, key) + tree_states.append(data) + tree_states = pd.concat(tree_states, axis=1) + tree_states["tree"] = name + states.append(tree_states) + states = pd.concat(states) + states["node"] = states.index + return states.reset_index(drop=True) diff --git a/src/pycea/tl/sort.py b/src/pycea/tl/sort.py index ef64d76..6930f65 100755 --- a/src/pycea/tl/sort.py +++ b/src/pycea/tl/sort.py @@ -21,7 +21,7 @@ def _sort_tree(tree, key, reverse=False): def sort(tdata: td.TreeData, key: str, reverse: bool = False, tree: str | Sequence[str] | None = None) -> None: - """Sorts the children of each internal node in a tree based on a given key. + """Reorders branches based on a given key. Parameters ---------- diff --git a/tests/test_ancestral_states.py b/tests/test_ancestral_states.py new file mode 100755 index 0000000..41f58f2 --- /dev/null +++ b/tests/test_ancestral_states.py @@ -0,0 +1,31 @@ +import networkx as nx +import pandas as pd +import pytest +import treedata as td + +from pycea.tl.ancestral_states import ancestral_states + + +@pytest.fixture +def tdata(): + tree1 = nx.DiGraph([("root", "B"), ("root", "C"), ("C", "D"), ("C", "E")]) + tree2 = nx.DiGraph([("root", "F")]) + tdata = td.TreeData( + obs=pd.DataFrame({"value": [0, 0, 3, 2], "str_value": ["0", "0", "3", "2"]}, index=["B", "D", "E", "F"]), + obst={"tree1": tree1, "tree2": tree2}, + ) + yield tdata + + +def test_ancestral_states(tdata): + # Mean + states = ancestral_states(tdata, "value", method="mean", copy=True) + assert tdata.obst["tree1"].nodes["root"]["value"] == 1 + assert tdata.obst["tree1"].nodes["C"]["value"] == 1.5 + assert states["value"].tolist() == [1, 0, 1.5, 0, 3, 2, 2] + # Median + states = ancestral_states(tdata, "value", method="median", copy=True) + assert tdata.obst["tree1"].nodes["root"]["value"] == 0 + # Mode + ancestral_states(tdata, "str_value", method="mode", copy=False, tree="tree1") + assert tdata.obst["tree1"].nodes["root"]["str_value"] == "0" From 5216b5edc537b6eacfcdf0e0e2cb61e509269a48 Mon Sep 17 00:00:00 2001 From: colganwi Date: Mon, 27 May 2024 19:11:25 -0400 Subject: [PATCH 2/4] sankoff --- src/pycea/tl/ancestral_states.py | 235 +++++++++++++++++++++++++++---- src/pycea/utils.py | 9 +- tests/test_ancestral_states.py | 75 +++++++++- 3 files changed, 286 insertions(+), 33 deletions(-) diff --git a/src/pycea/tl/ancestral_states.py b/src/pycea/tl/ancestral_states.py index 38f0ddc..de57ed1 100755 --- a/src/pycea/tl/ancestral_states.py +++ b/src/pycea/tl/ancestral_states.py @@ -7,7 +7,7 @@ import pandas as pd import treedata as td -from pycea.utils import get_keyed_node_data, get_keyed_obs_data, get_trees +from pycea.utils import get_keyed_node_data, get_keyed_obs_data, get_root, get_trees def _most_common(arr): @@ -17,37 +17,191 @@ def _most_common(arr): return unique_values[most_common_index] -def _ancestral_states(tree, key, method="mean"): - """Finds the ancestral state of a node in a tree.""" - # Get summation function - if method == "mean": - sum_func = np.mean - elif method == "median": - sum_func = np.median - elif method == "mode": - sum_func = _most_common +def _get_node_value(tree, node, key, index): + """Gets the value of a node attribute.""" + if key in tree.nodes[node]: + if index is not None: + return tree.nodes[node][key][index] + else: + return tree.nodes[node][key] else: - raise ValueError(f"Method {method} not recognized.") - # Get aggregation function - if method in ["mean", "median", "mode"]: - agg_func = np.concatenate - # infer ancestral states - for node in nx.dfs_postorder_nodes(tree): + return None + + +def _set_node_value(tree, node, key, value, index): + """Sets the value of a node attribute.""" + if index is not None: + tree.nodes[node][key][index] = value + else: + tree.nodes[node][key] = value + + +def _reconstruct_fitch_hartigan(tree, key, missing="-1", index=None): + """Reconstructs ancestral states using the Fitch-Hartigan algorithm.""" + + # Recursive function to calculate the downpass + def downpass(node): + # Base case: leaf if tree.out_degree(node) == 0: - tree.nodes[node]["_message"] = np.array([tree.nodes[node][key]]) + value = _get_node_value(tree, node, key, index) + if value == missing: + tree.nodes[node]["value_set"] = missing + else: + tree.nodes[node]["value_set"] = {value} + # Recursive case: internal node else: - subtree_values = agg_func([tree.nodes[child]["_message"] for child in tree.successors(node)]) - tree.nodes[node]["_message"] = subtree_values - tree.nodes[node][key] = sum_func(subtree_values) - # remove messages + value_sets = [] + for child in tree.successors(node): + downpass(child) + value_set = tree.nodes[child]["value_set"] + if value_set != missing: + value_sets.append(value_set) + if len(value_sets) > 0: + intersection = set.intersection(*value_sets) + if intersection: + tree.nodes[node]["value_set"] = intersection + else: + tree.nodes[node]["value_set"] = set.union(*value_sets) + else: + tree.nodes[node]["value_set"] = missing + + # Recursive function to calculate the uppass + def uppass(node, parent_state=None): + value = _get_node_value(tree, node, key, index) + if value is None: + if parent_state and parent_state in tree.nodes[node]["value_set"]: + value = parent_state + else: + value = min(tree.nodes[node]["value_set"]) + _set_node_value(tree, node, key, value, index) + elif value == missing: + value = parent_state + _set_node_value(tree, node, key, value, index) + for child in tree.successors(node): + uppass(child, value) + + # Run the algorithm + root = get_root(tree) + downpass(root) + uppass(root) + # Clean up + for node in tree.nodes: + if "value_set" in tree.nodes[node]: + del tree.nodes[node]["value_set"] + + +def _reconstruct_sankoff(tree, key, costs, missing="-1", index=None): + """Reconstructs ancestral states using the Sankoff algorithm.""" + + # Recursive function to calculate the Sankoff scores + def sankoff_scores(node): + # Base case: leaf + if tree.out_degree(node) == 0: + leaf_value = _get_node_value(tree, node, key, index) + if leaf_value == missing: + return {value: 0 for value in alphabet} + else: + return {value: 0 if value == leaf_value else float("inf") for value in alphabet} + # Recursive case: internal node + else: + scores = {value: 0 for value in alphabet} + pointers = {value: {} for value in alphabet} + for child in tree.successors(node): + child_scores = sankoff_scores(child) + for value in alphabet: + min_cost, min_value = float("inf"), None + for child_value in alphabet: + cost = child_scores[child_value] + costs.loc[value, child_value] + if cost < min_cost: + min_cost, min_value = cost, child_value + scores[value] += min_cost + pointers[value][child] = min_value + tree.nodes[node]["_pointers"] = pointers + return scores + + # Recursive function to traceback the Sankoff scores + def traceback(node, parent_value=None): + for child in tree.successors(node): + child_value = tree.nodes[node]["_pointers"][parent_value][child] + _set_node_value(tree, child, key, child_value, index) + traceback(child, child_value) + + # Get scores + root = get_root(tree) + alphabet = set(costs.index) + root_scores = sankoff_scores(root) + # Reconstruct ancestral states + root_value = min(root_scores, key=root_scores.get) + _set_node_value(tree, root, key, root_value, index) + traceback(root, root_value) + # Clean up for node in tree.nodes: - del tree.nodes[node]["_message"] + if "_pointers" in tree.nodes[node]: + del tree.nodes[node]["_pointers"] + + +def _reconstruct_mean(tree, key, index): + """Reconstructs ancestral by averaging the values of the children.""" + + def subtree_mean(node): + if tree.out_degree(node) == 0: + return _get_node_value(tree, node, key, index), 1 + else: + values, weights = [], [] + for child in tree.successors(node): + child_value, child_n = subtree_mean(child) + values.append(child_value) + weights.append(child_n) + mean_value = np.average(values, weights=weights) + _set_node_value(tree, node, key, mean_value, index) + return mean_value, sum(weights) + + root = get_root(tree) + subtree_mean(root) + + +def _reconstruct_list(tree, key, sum_func, index): + """Reconstructs ancestral states by concatenating the values of the children.""" + + def subtree_list(node): + if tree.out_degree(node) == 0: + return [_get_node_value(tree, node, key, index)] + else: + values = [] + for child in tree.successors(node): + values.extend(subtree_list(child)) + _set_node_value(tree, node, key, sum_func(values), index) + return values + + root = get_root(tree) + subtree_list(root) + + +def _ancestral_states(tree, key, method="mean", costs=None, missing=None, default=None, index=None): + """Reconstructs ancestral states for a given attribute using a given method""" + if method == "sankoff": + if costs is None: + raise ValueError("Costs matrix must be provided for Sankoff algorithm.") + _reconstruct_sankoff(tree, key, costs, missing, index) + elif method == "fitch_hartigan": + _reconstruct_fitch_hartigan(tree, key, missing, index) + elif method == "mean": + _reconstruct_mean(tree, key, index) + elif method == "mode": + _reconstruct_list(tree, key, _most_common, index) + elif callable(method): + _reconstruct_list(tree, key, method, index) + else: + raise ValueError(f"Method {method} not recognized.") def ancestral_states( tdata: td.TreeData, keys: str | Sequence[str], method: str = "mean", + missing_state: str = "-1", + default_state: str = "0", + costs: pd.DataFrame = None, tree: str | Sequence[str] | None = None, copy: bool = False, ) -> None: @@ -60,7 +214,14 @@ def ancestral_states( keys One or more `obs_keys`, `var_names`, `obsm_keys`, or `obsp_keys` to reconstruct. method - Method to reconstruct ancestral states. One of "mean", "median", or "mode". + Method to reconstruct ancestral states. One of "mean", "mode", "fitch_hartigan", "sankoff", + or any function that takes a list of values and returns a single value. + missing_state + The state to consider as missing data. + default_state + The expected state for the root node. + costs + A pd.DataFrame with the costs of changing states (from rows to columns). tree The `obst` key or keys of the trees to use. If `None`, all trees are used. copy @@ -71,10 +232,30 @@ def ancestral_states( tree_keys = tree trees = get_trees(tdata, tree_keys) for _, tree in trees.items(): - data, _ = get_keyed_obs_data(tdata, keys) - for key in keys: - nx.set_node_attributes(tree, data[key].to_dict(), key) - _ancestral_states(tree, key, method) + data, is_array = get_keyed_obs_data(tdata, keys) + dtypes = {dtype.kind for dtype in data.dtypes} + # Check data type + if dtypes.intersection({"i", "f"}): + if method in ["fitch_hartigan", "sankoff"]: + raise ValueError(f"Method {method} requires categorical data.") + if dtypes.intersection({"O", "S"}): + if method in ["mean"]: + raise ValueError(f"Method {method} requires numerical data.") + # If array add to tree as list + if is_array: + length = data.shape[1] + node_attrs = data.apply(lambda row: list(row), axis=1).to_dict() + for node in tree.nodes: + if node not in node_attrs: + node_attrs[node] = [None] * length + nx.set_node_attributes(tree, node_attrs, keys[0]) + for index in range(length): + _ancestral_states(tree, keys[0], method, costs, missing_state, default_state, index) + # If column add to tree as scalar + else: + for key in keys: + nx.set_node_attributes(tree, data[key].to_dict(), key) + _ancestral_states(tree, key, method, missing_state, default_state) if copy: states = [] for name, tree in trees.items(): diff --git a/src/pycea/utils.py b/src/pycea/utils.py index f4959bf..3e77a49 100755 --- a/src/pycea/utils.py +++ b/src/pycea/utils.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Sequence, Mapping +from collections.abc import Mapping, Sequence import networkx as nx import pandas as pd @@ -39,7 +39,8 @@ def get_keyed_edge_data(tree: nx.DiGraph | Mapping[str, nx.DiGraph], key: str) - (f"{name}{sep}{parent}", f"{name}{sep}{child}"): data.get(key) for parent, child, data in tree.edges(data=True) if key in data and data[key] is not None - }) + } + ) if len(edge_data) == 0: raise ValueError(f"Key {key!r} is not present in any edge.") return pd.Series(edge_data, name=key) @@ -60,7 +61,8 @@ def get_keyed_node_data(tree: nx.DiGraph | Mapping[str, nx.DiGraph], key: str) - f"{name}{sep}{node}": data.get(key) for node, data in tree.nodes(data=True) if key in data and data[key] is not None - }) + } + ) if len(node_data) == 0: raise ValueError(f"Key {key!r} is not present in any node.") return pd.Series(node_data, name=key) @@ -103,7 +105,6 @@ def get_keyed_obs_data(tdata: td.TreeData, keys: Sequence[str], layer: str = Non data.columns = keys elif array_keys: data = pd.DataFrame(data[0], index=tdata.obs_names) - if data.shape[0] == data.shape[1]: data.columns = tdata.obs_names return data, array_keys diff --git a/tests/test_ancestral_states.py b/tests/test_ancestral_states.py index 41f58f2..e541a05 100755 --- a/tests/test_ancestral_states.py +++ b/tests/test_ancestral_states.py @@ -1,4 +1,5 @@ import networkx as nx +import numpy as np import pandas as pd import pytest import treedata as td @@ -10,9 +11,15 @@ def tdata(): tree1 = nx.DiGraph([("root", "B"), ("root", "C"), ("C", "D"), ("C", "E")]) tree2 = nx.DiGraph([("root", "F")]) + spatial = np.array([[0, 4], [1, 1], [2, 1], [4, 4]]) + characters = np.array([["-1", "0"], ["1", "1"], ["2", "-1"], ["1", "2"]]) tdata = td.TreeData( - obs=pd.DataFrame({"value": [0, 0, 3, 2], "str_value": ["0", "0", "3", "2"]}, index=["B", "D", "E", "F"]), + obs=pd.DataFrame( + {"value": [0, 0, 3, 2], "str_value": ["0", "0", "3", "2"], "with_missing": [0, np.nan, 3, 2]}, + index=["B", "D", "E", "F"], + ), obst={"tree1": tree1, "tree2": tree2}, + obsm={"spatial": spatial, "characters": characters}, ) yield tdata @@ -24,8 +31,72 @@ def test_ancestral_states(tdata): assert tdata.obst["tree1"].nodes["C"]["value"] == 1.5 assert states["value"].tolist() == [1, 0, 1.5, 0, 3, 2, 2] # Median - states = ancestral_states(tdata, "value", method="median", copy=True) + states = ancestral_states(tdata, "value", method=np.median, copy=True) assert tdata.obst["tree1"].nodes["root"]["value"] == 0 # Mode ancestral_states(tdata, "str_value", method="mode", copy=False, tree="tree1") + for node in tdata.obst["tree1"].nodes: + print(node, tdata.obst["tree1"].nodes[node]) assert tdata.obst["tree1"].nodes["root"]["str_value"] == "0" + + +def test_ancestral_states_array(tdata): + # Mean + states = ancestral_states(tdata, "spatial", method="mean", copy=True) + print(states) + assert tdata.obst["tree1"].nodes["root"]["spatial"] == [1.0, 2.0] + assert tdata.obst["tree1"].nodes["C"]["spatial"] == [1.5, 1.0] + assert states["spatial"][0] == [1.0, 2.0] + # Median + states = ancestral_states(tdata, "spatial", method=np.median, copy=True) + assert tdata.obst["tree1"].nodes["root"]["spatial"] == [1.0, 1.0] + + +def test_ancestral_states_missing(tdata): + # Mean + states = ancestral_states(tdata, "with_missing", method=np.nanmean, copy=True) + print(states) + assert tdata.obst["tree1"].nodes["root"]["with_missing"] == 1.5 + assert tdata.obst["tree1"].nodes["C"]["with_missing"] == 3 + assert states["with_missing"][0] == 1.5 + + +def test_ancestral_state_fitch(tdata): + states = ancestral_states(tdata, "characters", method="fitch_hartigan", copy=True) + assert tdata.obst["tree1"].nodes["root"]["characters"] == ["1", "0"] + assert tdata.obst["tree2"].nodes["F"]["characters"] == ["1", "2"] + assert states["characters"][0] == ["1", "0"] + + +def test_ancestral_states_sankoff(tdata): + costs = pd.DataFrame( + [[0, 1, 2], [10, 0, 10], [10, 10, 0]], + index=["0", "1", "2"], + columns=["0", "1", "2"], + ) + states = ancestral_states(tdata, "characters", method="sankoff", costs=costs, copy=True) + assert tdata.obst["tree1"].nodes["root"]["characters"] == ["0", "0"] + assert tdata.obst["tree2"].nodes["F"]["characters"] == ["1", "2"] + assert states["characters"][0] == ["0", "0"] + costs = pd.DataFrame( + [[0, 10, 10], [1, 0, 2], [2, 1, 0]], + index=["0", "1", "2"], + columns=["0", "1", "2"], + ) + states = ancestral_states(tdata, "characters", method="sankoff", costs=costs, copy=True) + assert tdata.obst["tree1"].nodes["root"]["characters"] == ["2", "1"] + + +def test_ancestral_states_invalid(tdata): + with pytest.raises(ValueError): + ancestral_states(tdata, "characters", method="sankoff") + with pytest.raises(ValueError): + ancestral_states(tdata, "characters", method="sankoff", costs=pd.DataFrame()) + with pytest.raises(ValueError): + ancestral_states(tdata, "bad", method="mean") + with pytest.raises(ValueError): + ancestral_states(tdata, "value", method="bad") + with pytest.raises(ValueError): + ancestral_states(tdata, "value", method="fitch_hartigan", copy=False) + with pytest.raises(ValueError): + ancestral_states(tdata, "str_value", method="mean", copy=False) From ddcedf237f648b02fe583171159c51ba00c8487a Mon Sep 17 00:00:00 2001 From: colganwi Date: Mon, 27 May 2024 20:17:07 -0400 Subject: [PATCH 3/4] updated utils --- src/pycea/pl/plot_tree.py | 26 +++++++++--- src/pycea/tl/ancestral_states.py | 13 +----- src/pycea/utils.py | 70 ++++++++++++++++---------------- tests/test_ancestral_states.py | 9 ++-- tests/test_utils.py | 24 +++++------ 5 files changed, 72 insertions(+), 70 deletions(-) diff --git a/src/pycea/pl/plot_tree.py b/src/pycea/pl/plot_tree.py index a35b3e6..ef9172d 100644 --- a/src/pycea/pl/plot_tree.py +++ b/src/pycea/pl/plot_tree.py @@ -101,7 +101,11 @@ def branches( if mcolors.is_color_like(color): kwargs.update({"color": color}) elif isinstance(color, str): - color_data = get_keyed_edge_data(trees, color) + color_data = get_keyed_edge_data(tdata, color, tree_keys)[color] + print(color_data) + if len(color_data) == 0: + raise ValueError(f"Key {color!r} is not present in any edge.") + color_data.index = color_data.index.map(lambda x: f"{x[0]}-{x[1][0]}-{x[1][1]}") if color_data.dtype.kind in ["i", "f"]: norm = plt.Normalize(vmin=color_data.min(), vmax=color_data.max()) cmap = plt.get_cmap(cmap) @@ -125,7 +129,10 @@ def branches( if isinstance(linewidth, (int, float)): kwargs.update({"linewidth": linewidth}) elif isinstance(linewidth, str): - linewidth_data = get_keyed_edge_data(trees, linewidth) + linewidth_data = get_keyed_edge_data(tdata, linewidth, tree_keys)[linewidth] + if len(linewidth_data) == 0: + raise ValueError(f"Key {linewidth!r} is not present in any edge.") + linewidth_data.index = linewidth_data.index.map(lambda x: f"{x[0]}-{x[1][0]}-{x[1][1]}") if linewidth_data.dtype.kind in ["i", "f"]: linewidths = [linewidth_data[edge] if edge in linewidth_data.index else na_linewidth for edge in edges] kwargs.update({"linewidth": linewidths}) @@ -271,7 +278,10 @@ def nodes( if mcolors.is_color_like(color): kwargs.update({"color": color}) elif isinstance(color, str): - color_data = get_keyed_node_data(trees, color) + color_data = get_keyed_node_data(tdata, color, tree_keys)[color] + if len(color_data) == 0: + raise ValueError(f"Key {color!r} is not present in any node.") + color_data.index = color_data.index.map("-".join) if color_data.dtype.kind in ["i", "f"]: if not vmin: vmin = color_data.min() @@ -298,7 +308,10 @@ def nodes( if isinstance(size, (int, float)): kwargs.update({"s": size}) elif isinstance(size, str): - size_data = get_keyed_node_data(trees, size) + size_data = get_keyed_node_data(tdata, size, tree_keys)[size] + if len(size_data) == 0: + raise ValueError(f"Key {size!r} is not present in any node.") + size_data.index = size_data.index.map("-".join) sizes = [size_data[node] if node in size_data.index else na_size for node in nodes] kwargs.update({"s": sizes}) else: @@ -307,7 +320,10 @@ def nodes( if style in mmarkers.MarkerStyle.markers: kwargs.update({"marker": style}) elif isinstance(style, str): - style_data = get_keyed_node_data(trees, style) + style_data = get_keyed_node_data(tdata, style, tree_keys)[style] + if len(style_data) == 0: + raise ValueError(f"Key {style!r} is not present in any node.") + style_data.index = style_data.index.map("-".join) mmap = _get_categorical_markers(tdata, style, style_data, markers) styles = [mmap[style_data[node]] if node in style_data.index else na_style for node in nodes] for style in set(styles): diff --git a/src/pycea/tl/ancestral_states.py b/src/pycea/tl/ancestral_states.py index de57ed1..bb5fc2d 100755 --- a/src/pycea/tl/ancestral_states.py +++ b/src/pycea/tl/ancestral_states.py @@ -257,15 +257,4 @@ def ancestral_states( nx.set_node_attributes(tree, data[key].to_dict(), key) _ancestral_states(tree, key, method, missing_state, default_state) if copy: - states = [] - for name, tree in trees.items(): - tree_states = [] - for key in keys: - data = get_keyed_node_data(tree, key) - tree_states.append(data) - tree_states = pd.concat(tree_states, axis=1) - tree_states["tree"] = name - states.append(tree_states) - states = pd.concat(states) - states["node"] = states.index - return states.reset_index(drop=True) + return get_keyed_node_data(tdata, keys, tree_keys) diff --git a/src/pycea/utils.py b/src/pycea/utils.py index 3e77a49..7648411 100755 --- a/src/pycea/utils.py +++ b/src/pycea/utils.py @@ -24,48 +24,46 @@ def get_leaves(tree: nx.DiGraph): return [node for node in nx.dfs_postorder_nodes(tree, get_root(tree)) if tree.out_degree(node) == 0] -def get_keyed_edge_data(tree: nx.DiGraph | Mapping[str, nx.DiGraph], key: str) -> pd.Series: +def get_keyed_edge_data( + tdata: td.TreeData, keys: str | Sequence[str], tree_keys: str | Sequence[str] = None +) -> pd.DataFrame: """Gets edge data for a given key from a tree or set of trees.""" - if isinstance(tree, nx.DiGraph): - trees = {"": tree} - sep = "" - else: - trees = tree - sep = "-" - edge_data = {} + if isinstance(tree_keys, str): + tree_keys = [tree_keys] + if isinstance(keys, str): + keys = [keys] + trees = get_trees(tdata, tree_keys) + data = [] for name, tree in trees.items(): - edge_data.update( - { - (f"{name}{sep}{parent}", f"{name}{sep}{child}"): data.get(key) - for parent, child, data in tree.edges(data=True) - if key in data and data[key] is not None - } - ) - if len(edge_data) == 0: - raise ValueError(f"Key {key!r} is not present in any edge.") - return pd.Series(edge_data, name=key) + edge_data = {key: nx.get_edge_attributes(tree, key) for key in keys} + edge_data = pd.DataFrame(edge_data) + edge_data["tree"] = name + edge_data["edge"] = edge_data.index + data.append(edge_data) + data = pd.concat(data) + data = data.set_index(["tree", "edge"]) + return data -def get_keyed_node_data(tree: nx.DiGraph | Mapping[str, nx.DiGraph], key: str) -> pd.Series: +def get_keyed_node_data( + tdata: td.TreeData, keys: str | Sequence[str], tree_keys: str | Sequence[str] = None +) -> pd.DataFrame: """Gets node data for a given key a tree or set of trees.""" - if isinstance(tree, nx.DiGraph): - trees = {"": tree} - sep = "" - else: - trees = tree - sep = "-" - node_data = {} + if isinstance(tree_keys, str): + tree_keys = [tree_keys] + if isinstance(keys, str): + keys = [keys] + trees = get_trees(tdata, tree_keys) + data = [] for name, tree in trees.items(): - node_data.update( - { - f"{name}{sep}{node}": data.get(key) - for node, data in tree.nodes(data=True) - if key in data and data[key] is not None - } - ) - if len(node_data) == 0: - raise ValueError(f"Key {key!r} is not present in any node.") - return pd.Series(node_data, name=key) + tree_data = {key: nx.get_node_attributes(tree, key) for key in keys} + tree_data = pd.DataFrame(tree_data) + tree_data["tree"] = name + data.append(tree_data) + data = pd.concat(data) + data["node"] = data.index + data = data.set_index(["tree", "node"]) + return data def get_keyed_obs_data(tdata: td.TreeData, keys: Sequence[str], layer: str = None) -> pd.DataFrame: diff --git a/tests/test_ancestral_states.py b/tests/test_ancestral_states.py index e541a05..8ea8662 100755 --- a/tests/test_ancestral_states.py +++ b/tests/test_ancestral_states.py @@ -46,7 +46,7 @@ def test_ancestral_states_array(tdata): print(states) assert tdata.obst["tree1"].nodes["root"]["spatial"] == [1.0, 2.0] assert tdata.obst["tree1"].nodes["C"]["spatial"] == [1.5, 1.0] - assert states["spatial"][0] == [1.0, 2.0] + assert states.loc[("tree1", "root"), "spatial"] == [1.0, 2.0] # Median states = ancestral_states(tdata, "spatial", method=np.median, copy=True) assert tdata.obst["tree1"].nodes["root"]["spatial"] == [1.0, 1.0] @@ -58,14 +58,15 @@ def test_ancestral_states_missing(tdata): print(states) assert tdata.obst["tree1"].nodes["root"]["with_missing"] == 1.5 assert tdata.obst["tree1"].nodes["C"]["with_missing"] == 3 - assert states["with_missing"][0] == 1.5 + assert states.loc[("tree1", "root"), "with_missing"] == 1.5 def test_ancestral_state_fitch(tdata): states = ancestral_states(tdata, "characters", method="fitch_hartigan", copy=True) assert tdata.obst["tree1"].nodes["root"]["characters"] == ["1", "0"] assert tdata.obst["tree2"].nodes["F"]["characters"] == ["1", "2"] - assert states["characters"][0] == ["1", "0"] + print(states) + assert states.loc[("tree1", "root"), "characters"] == ["1", "0"] def test_ancestral_states_sankoff(tdata): @@ -77,7 +78,7 @@ def test_ancestral_states_sankoff(tdata): states = ancestral_states(tdata, "characters", method="sankoff", costs=costs, copy=True) assert tdata.obst["tree1"].nodes["root"]["characters"] == ["0", "0"] assert tdata.obst["tree2"].nodes["F"]["characters"] == ["1", "2"] - assert states["characters"][0] == ["0", "0"] + assert states.loc[("tree1", "root"), "characters"] == ["0", "0"] costs = pd.DataFrame( [[0, 10, 10], [1, 0, 2], [2, 1, 0]], index=["0", "1", "2"], diff --git a/tests/test_utils.py b/tests/test_utils.py index 3f401d5..22f8491 100755 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,7 +2,7 @@ import pandas as pd import pytest -from pycea.utils import get_keyed_edge_data, get_keyed_node_data, get_keyed_obs_data, get_root, get_leaves +from pycea.utils import get_keyed_edge_data, get_keyed_node_data, get_keyed_obs_data, get_leaves, get_root @pytest.fixture @@ -31,20 +31,14 @@ def test_get_leaves(tree): assert get_leaves(nx.DiGraph()) == [] -def test_get_keyed_edge_data(tree): - result = get_keyed_edge_data(tree, "weight") - expected_keys = [("A", "B"), ("B", "D"), ("C", "E")] - expected_values = [5, 3, 4] - assert all(result[key] == value for key, value in zip(expected_keys, expected_values)) - assert ("A", "C") not in result +def test_get_keyed_edge_data(tdata): + data = get_keyed_edge_data(tdata, ["length", "clade"]) + assert data.columns.tolist() == ["length", "clade"] -def test_get_keyed_node_data(tree): - result = get_keyed_node_data(tree, "value") - expected_keys = ["A", "B", "D", "E"] - expected_values = [1, 2, 4, 5] - assert all(result[key] == value for key, value in zip(expected_keys, expected_values)) - assert "C" not in result +def test_get_keyed_node_data(tdata): + data = get_keyed_node_data(tdata, ["x", "y", "clade"]) + assert data.columns.tolist() == ["x", "y", "clade"] def test_get_keyed_obs_data_valid_keys(tdata): @@ -70,3 +64,7 @@ def test_get_keyed_obs_data_invalid_keys(tdata): get_keyed_obs_data(tdata, ["clade", "x", "0", "invalid_key"]) with pytest.raises(ValueError): get_keyed_obs_data(tdata, ["clade", "spatial_distance"]) + + +if __name__ == "__main__": + pytest.main(["-v", __file__]) From 1866f437fc7ef13e66187f44705e17af2163d2fc Mon Sep 17 00:00:00 2001 From: colganwi Date: Tue, 28 May 2024 17:41:09 -0400 Subject: [PATCH 4/4] cleaned up tests --- tests/test_ancestral_states.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/test_ancestral_states.py b/tests/test_ancestral_states.py index 8ea8662..085a3e4 100755 --- a/tests/test_ancestral_states.py +++ b/tests/test_ancestral_states.py @@ -43,7 +43,6 @@ def test_ancestral_states(tdata): def test_ancestral_states_array(tdata): # Mean states = ancestral_states(tdata, "spatial", method="mean", copy=True) - print(states) assert tdata.obst["tree1"].nodes["root"]["spatial"] == [1.0, 2.0] assert tdata.obst["tree1"].nodes["C"]["spatial"] == [1.5, 1.0] assert states.loc[("tree1", "root"), "spatial"] == [1.0, 2.0] @@ -55,7 +54,6 @@ def test_ancestral_states_array(tdata): def test_ancestral_states_missing(tdata): # Mean states = ancestral_states(tdata, "with_missing", method=np.nanmean, copy=True) - print(states) assert tdata.obst["tree1"].nodes["root"]["with_missing"] == 1.5 assert tdata.obst["tree1"].nodes["C"]["with_missing"] == 3 assert states.loc[("tree1", "root"), "with_missing"] == 1.5 @@ -65,7 +63,6 @@ def test_ancestral_state_fitch(tdata): states = ancestral_states(tdata, "characters", method="fitch_hartigan", copy=True) assert tdata.obst["tree1"].nodes["root"]["characters"] == ["1", "0"] assert tdata.obst["tree2"].nodes["F"]["characters"] == ["1", "2"] - print(states) assert states.loc[("tree1", "root"), "characters"] == ["1", "0"]