Skip to content

Commit

Permalink
Merge pull request #5 from YosefLab/ancestral-states
Browse files Browse the repository at this point in the history
ancestral states
  • Loading branch information
colganwi authored May 28, 2024
2 parents e89beb6 + 1866f43 commit e5596a2
Show file tree
Hide file tree
Showing 9 changed files with 431 additions and 55 deletions.
1 change: 1 addition & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
.. autosummary::
:toctree: generated
tl.ancestral_states
tl.clades
tl.sort
```
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ dependencies = [
"numpy",
"pandas",
"session-info",
"scipy",
]

[project.optional-dependencies]
Expand Down
26 changes: 21 additions & 5 deletions src/pycea/pl/plot_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,11 @@ def branches(
if mcolors.is_color_like(color):
kwargs.update({"color": color})
elif isinstance(color, str):
color_data = get_keyed_edge_data(trees, color)
color_data = get_keyed_edge_data(tdata, color, tree_keys)[color]
print(color_data)
if len(color_data) == 0:
raise ValueError(f"Key {color!r} is not present in any edge.")
color_data.index = color_data.index.map(lambda x: f"{x[0]}-{x[1][0]}-{x[1][1]}")
if color_data.dtype.kind in ["i", "f"]:
norm = plt.Normalize(vmin=color_data.min(), vmax=color_data.max())
cmap = plt.get_cmap(cmap)
Expand All @@ -125,7 +129,10 @@ def branches(
if isinstance(linewidth, (int, float)):
kwargs.update({"linewidth": linewidth})
elif isinstance(linewidth, str):
linewidth_data = get_keyed_edge_data(trees, linewidth)
linewidth_data = get_keyed_edge_data(tdata, linewidth, tree_keys)[linewidth]
if len(linewidth_data) == 0:
raise ValueError(f"Key {linewidth!r} is not present in any edge.")
linewidth_data.index = linewidth_data.index.map(lambda x: f"{x[0]}-{x[1][0]}-{x[1][1]}")
if linewidth_data.dtype.kind in ["i", "f"]:
linewidths = [linewidth_data[edge] if edge in linewidth_data.index else na_linewidth for edge in edges]
kwargs.update({"linewidth": linewidths})
Expand Down Expand Up @@ -271,7 +278,10 @@ def nodes(
if mcolors.is_color_like(color):
kwargs.update({"color": color})
elif isinstance(color, str):
color_data = get_keyed_node_data(trees, color)
color_data = get_keyed_node_data(tdata, color, tree_keys)[color]
if len(color_data) == 0:
raise ValueError(f"Key {color!r} is not present in any node.")
color_data.index = color_data.index.map("-".join)
if color_data.dtype.kind in ["i", "f"]:
if not vmin:
vmin = color_data.min()
Expand All @@ -298,7 +308,10 @@ def nodes(
if isinstance(size, (int, float)):
kwargs.update({"s": size})
elif isinstance(size, str):
size_data = get_keyed_node_data(trees, size)
size_data = get_keyed_node_data(tdata, size, tree_keys)[size]
if len(size_data) == 0:
raise ValueError(f"Key {size!r} is not present in any node.")
size_data.index = size_data.index.map("-".join)
sizes = [size_data[node] if node in size_data.index else na_size for node in nodes]
kwargs.update({"s": sizes})
else:
Expand All @@ -307,7 +320,10 @@ def nodes(
if style in mmarkers.MarkerStyle.markers:
kwargs.update({"marker": style})
elif isinstance(style, str):
style_data = get_keyed_node_data(trees, style)
style_data = get_keyed_node_data(tdata, style, tree_keys)[style]
if len(style_data) == 0:
raise ValueError(f"Key {style!r} is not present in any node.")
style_data.index = style_data.index.map("-".join)
mmap = _get_categorical_markers(tdata, style, style_data, markers)
styles = [mmap[style_data[node]] if node in style_data.index else na_style for node in nodes]
for style in set(styles):
Expand Down
1 change: 1 addition & 0 deletions src/pycea/tl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .clades import clades
from .sort import sort
from .ancestral_states import ancestral_states
260 changes: 260 additions & 0 deletions src/pycea/tl/ancestral_states.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,260 @@
from __future__ import annotations

from collections.abc import Sequence

import networkx as nx
import numpy as np
import pandas as pd
import treedata as td

from pycea.utils import get_keyed_node_data, get_keyed_obs_data, get_root, get_trees


def _most_common(arr):
"""Finds the most common element in a list."""
unique_values, counts = np.unique(arr, return_counts=True)
most_common_index = np.argmax(counts)
return unique_values[most_common_index]


def _get_node_value(tree, node, key, index):
"""Gets the value of a node attribute."""
if key in tree.nodes[node]:
if index is not None:
return tree.nodes[node][key][index]
else:
return tree.nodes[node][key]
else:
return None


def _set_node_value(tree, node, key, value, index):
"""Sets the value of a node attribute."""
if index is not None:
tree.nodes[node][key][index] = value
else:
tree.nodes[node][key] = value


def _reconstruct_fitch_hartigan(tree, key, missing="-1", index=None):
"""Reconstructs ancestral states using the Fitch-Hartigan algorithm."""

# Recursive function to calculate the downpass
def downpass(node):
# Base case: leaf
if tree.out_degree(node) == 0:
value = _get_node_value(tree, node, key, index)
if value == missing:
tree.nodes[node]["value_set"] = missing
else:
tree.nodes[node]["value_set"] = {value}
# Recursive case: internal node
else:
value_sets = []
for child in tree.successors(node):
downpass(child)
value_set = tree.nodes[child]["value_set"]
if value_set != missing:
value_sets.append(value_set)
if len(value_sets) > 0:
intersection = set.intersection(*value_sets)
if intersection:
tree.nodes[node]["value_set"] = intersection
else:
tree.nodes[node]["value_set"] = set.union(*value_sets)
else:
tree.nodes[node]["value_set"] = missing

# Recursive function to calculate the uppass
def uppass(node, parent_state=None):
value = _get_node_value(tree, node, key, index)
if value is None:
if parent_state and parent_state in tree.nodes[node]["value_set"]:
value = parent_state
else:
value = min(tree.nodes[node]["value_set"])
_set_node_value(tree, node, key, value, index)
elif value == missing:
value = parent_state
_set_node_value(tree, node, key, value, index)
for child in tree.successors(node):
uppass(child, value)

# Run the algorithm
root = get_root(tree)
downpass(root)
uppass(root)
# Clean up
for node in tree.nodes:
if "value_set" in tree.nodes[node]:
del tree.nodes[node]["value_set"]


def _reconstruct_sankoff(tree, key, costs, missing="-1", index=None):
"""Reconstructs ancestral states using the Sankoff algorithm."""

# Recursive function to calculate the Sankoff scores
def sankoff_scores(node):
# Base case: leaf
if tree.out_degree(node) == 0:
leaf_value = _get_node_value(tree, node, key, index)
if leaf_value == missing:
return {value: 0 for value in alphabet}
else:
return {value: 0 if value == leaf_value else float("inf") for value in alphabet}
# Recursive case: internal node
else:
scores = {value: 0 for value in alphabet}
pointers = {value: {} for value in alphabet}
for child in tree.successors(node):
child_scores = sankoff_scores(child)
for value in alphabet:
min_cost, min_value = float("inf"), None
for child_value in alphabet:
cost = child_scores[child_value] + costs.loc[value, child_value]
if cost < min_cost:
min_cost, min_value = cost, child_value
scores[value] += min_cost
pointers[value][child] = min_value
tree.nodes[node]["_pointers"] = pointers
return scores

# Recursive function to traceback the Sankoff scores
def traceback(node, parent_value=None):
for child in tree.successors(node):
child_value = tree.nodes[node]["_pointers"][parent_value][child]
_set_node_value(tree, child, key, child_value, index)
traceback(child, child_value)

# Get scores
root = get_root(tree)
alphabet = set(costs.index)
root_scores = sankoff_scores(root)
# Reconstruct ancestral states
root_value = min(root_scores, key=root_scores.get)
_set_node_value(tree, root, key, root_value, index)
traceback(root, root_value)
# Clean up
for node in tree.nodes:
if "_pointers" in tree.nodes[node]:
del tree.nodes[node]["_pointers"]


def _reconstruct_mean(tree, key, index):
"""Reconstructs ancestral by averaging the values of the children."""

def subtree_mean(node):
if tree.out_degree(node) == 0:
return _get_node_value(tree, node, key, index), 1
else:
values, weights = [], []
for child in tree.successors(node):
child_value, child_n = subtree_mean(child)
values.append(child_value)
weights.append(child_n)
mean_value = np.average(values, weights=weights)
_set_node_value(tree, node, key, mean_value, index)
return mean_value, sum(weights)

root = get_root(tree)
subtree_mean(root)


def _reconstruct_list(tree, key, sum_func, index):
"""Reconstructs ancestral states by concatenating the values of the children."""

def subtree_list(node):
if tree.out_degree(node) == 0:
return [_get_node_value(tree, node, key, index)]
else:
values = []
for child in tree.successors(node):
values.extend(subtree_list(child))
_set_node_value(tree, node, key, sum_func(values), index)
return values

root = get_root(tree)
subtree_list(root)


def _ancestral_states(tree, key, method="mean", costs=None, missing=None, default=None, index=None):
"""Reconstructs ancestral states for a given attribute using a given method"""
if method == "sankoff":
if costs is None:
raise ValueError("Costs matrix must be provided for Sankoff algorithm.")
_reconstruct_sankoff(tree, key, costs, missing, index)
elif method == "fitch_hartigan":
_reconstruct_fitch_hartigan(tree, key, missing, index)
elif method == "mean":
_reconstruct_mean(tree, key, index)
elif method == "mode":
_reconstruct_list(tree, key, _most_common, index)
elif callable(method):
_reconstruct_list(tree, key, method, index)
else:
raise ValueError(f"Method {method} not recognized.")


def ancestral_states(
tdata: td.TreeData,
keys: str | Sequence[str],
method: str = "mean",
missing_state: str = "-1",
default_state: str = "0",
costs: pd.DataFrame = None,
tree: str | Sequence[str] | None = None,
copy: bool = False,
) -> None:
"""Reconstructs ancestral states for an attribute.
Parameters
----------
tdata
TreeData object.
keys
One or more `obs_keys`, `var_names`, `obsm_keys`, or `obsp_keys` to reconstruct.
method
Method to reconstruct ancestral states. One of "mean", "mode", "fitch_hartigan", "sankoff",
or any function that takes a list of values and returns a single value.
missing_state
The state to consider as missing data.
default_state
The expected state for the root node.
costs
A pd.DataFrame with the costs of changing states (from rows to columns).
tree
The `obst` key or keys of the trees to use. If `None`, all trees are used.
copy
If True, returns a pd.DataFrame with ancestral states.
"""
if isinstance(keys, str):
keys = [keys]
tree_keys = tree
trees = get_trees(tdata, tree_keys)
for _, tree in trees.items():
data, is_array = get_keyed_obs_data(tdata, keys)
dtypes = {dtype.kind for dtype in data.dtypes}
# Check data type
if dtypes.intersection({"i", "f"}):
if method in ["fitch_hartigan", "sankoff"]:
raise ValueError(f"Method {method} requires categorical data.")
if dtypes.intersection({"O", "S"}):
if method in ["mean"]:
raise ValueError(f"Method {method} requires numerical data.")
# If array add to tree as list
if is_array:
length = data.shape[1]
node_attrs = data.apply(lambda row: list(row), axis=1).to_dict()
for node in tree.nodes:
if node not in node_attrs:
node_attrs[node] = [None] * length
nx.set_node_attributes(tree, node_attrs, keys[0])
for index in range(length):
_ancestral_states(tree, keys[0], method, costs, missing_state, default_state, index)
# If column add to tree as scalar
else:
for key in keys:
nx.set_node_attributes(tree, data[key].to_dict(), key)
_ancestral_states(tree, key, method, missing_state, default_state)
if copy:
return get_keyed_node_data(tdata, keys, tree_keys)
2 changes: 1 addition & 1 deletion src/pycea/tl/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def _sort_tree(tree, key, reverse=False):


def sort(tdata: td.TreeData, key: str, reverse: bool = False, tree: str | Sequence[str] | None = None) -> None:
"""Sorts the children of each internal node in a tree based on a given key.
"""Reorders branches based on a given key.
Parameters
----------
Expand Down
Loading

0 comments on commit e5596a2

Please sign in to comment.