Skip to content

Commit

Permalink
Merge pull request #7 from YosefLab/distance
Browse files Browse the repository at this point in the history
Distance
  • Loading branch information
colganwi authored Aug 17, 2024
2 parents 8fffb2c + b7aa66b commit 995f12a
Show file tree
Hide file tree
Showing 27 changed files with 1,573 additions and 155 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.10
- name: Set up Python 3.11
uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.11"
cache: "pip"
cache-dependency-path: "**/pyproject.toml"
- name: Install build dependencies
Expand Down
8 changes: 2 additions & 6 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,9 @@ jobs:
matrix:
include:
- os: ubuntu-latest
python: "3.9"
python: "3.10"
- os: ubuntu-latest
python: "3.11"
- os: ubuntu-latest
python: "3.11"
pip-flags: "--pre"
name: PRE-RELEASE DEPENDENCIES
python: "3.12"

name: ${{ matrix.name }} Python ${{ matrix.python }}

Expand Down
4 changes: 4 additions & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@
tl.ancestral_states
tl.clades
tl.compare_distance
tl.distance
tl.sort
tl.tree_distance
tl.tree_neighbors
```

## Plotting
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dependencies = [
"pandas",
"session-info",
"scipy",
"scikit-learn",
]

[project.optional-dependencies]
Expand Down
18 changes: 12 additions & 6 deletions src/pycea/pl/_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Plotting utilities"""

from __future__ import annotations

import warnings
Expand Down Expand Up @@ -115,10 +116,14 @@ def layout_trees(
# Get leaf coordinates
leaves = []
depths = []
for _ , tree in trees.items():
for _, tree in trees.items():
tree_leaves = get_leaves(tree)
leaves.extend(tree_leaves)
depths.extend(tree.nodes[leaf].get(depth_key) for leaf in tree_leaves)
if len(depths) != len(leaves):
raise ValueError(
f"Tree does not have {depth_key} attribute. You can run `pycea.pp.add_depth` to add depth attribute."
)
max_depth = max(depths)
n_leaves = len(leaves)
leaf_coords = {}
Expand All @@ -132,13 +137,14 @@ def layout_trees(
node_coords = {}
branch_coords = {}
for key, tree in trees.items():
tree_node_coords,tree_branch_coords = layout_nodes_and_branches(tree, leaf_coords, depth_key, polar, angled_branches)
node_coords.update({f"{key}-{node}": coords for node, coords in tree_node_coords.items()})
branch_coords.update({(f"{key}-{parent}", f"{key}-{child}"): coords for (parent, child), coords in tree_branch_coords.items()})
tree_node_coords, tree_branch_coords = layout_nodes_and_branches(
tree, leaf_coords, depth_key, polar, angled_branches
)
node_coords.update({(key, node): coords for node, coords in tree_node_coords.items()})
branch_coords.update({(key, edge): coords for edge, coords in tree_branch_coords.items()})
return node_coords, branch_coords, leaves, max_depth



def _get_default_categorical_colors(length):
"""Get default categorical colors for plotting."""
# check if default matplotlib palette has enough colors
Expand Down Expand Up @@ -257,7 +263,7 @@ def _series_to_rgb_array(series, colors, vmin=None, vmax=None, na_color="#808080
"""Converts a pandas Series to an N x 3 numpy array based using a color map."""
if isinstance(colors, dict):
# Map using the dictionary
color_series = series.map(colors)
color_series = series.map(colors).astype("object")
color_series[series.isna()] = na_color
rgb_array = np.array([mcolors.to_rgb(color) for color in color_series])
elif isinstance(colors, mcolors.ListedColormap):
Expand Down
25 changes: 8 additions & 17 deletions src/pycea/pl/plot_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,8 @@ def branches(
kwargs.update({"color": color})
elif isinstance(color, str):
color_data = get_keyed_edge_data(tdata, color, tree_keys)[color]
print(color_data)
if len(color_data) == 0:
raise ValueError(f"Key {color!r} is not present in any edge.")
color_data.index = color_data.index.map(lambda x: f"{x[0]}-{x[1][0]}-{x[1][1]}")
if color_data.dtype.kind in ["i", "f"]:
norm = plt.Normalize(vmin=color_data.min(), vmax=color_data.max())
cmap = plt.get_cmap(cmap)
Expand All @@ -132,7 +130,6 @@ def branches(
linewidth_data = get_keyed_edge_data(tdata, linewidth, tree_keys)[linewidth]
if len(linewidth_data) == 0:
raise ValueError(f"Key {linewidth!r} is not present in any edge.")
linewidth_data.index = linewidth_data.index.map(lambda x: f"{x[0]}-{x[1][0]}-{x[1][1]}")
if linewidth_data.dtype.kind in ["i", "f"]:
linewidths = [linewidth_data[edge] if edge in linewidth_data.index else na_linewidth for edge in edges]
kwargs.update({"linewidth": linewidths})
Expand Down Expand Up @@ -244,23 +241,21 @@ def nodes(
tree_keys = [tree_keys]
if not set(tree_keys).issubset(attrs["tree_keys"]):
raise ValueError("Invalid tree key. Must be one of the keys used to plot the branches.")
trees = get_trees(tdata, attrs["tree_keys"])
# Get nodes
all_nodes = set()
all_nodes = []
for node in list(attrs["node_coords"].keys()):
if any(node.startswith(key) for key in tree_keys):
all_nodes.add(node)
leaves = set(attrs["leaves"])
if node[0] in tree_keys:
all_nodes.append(node)
if nodes == "all":
nodes = list(all_nodes)
nodes = all_nodes
elif nodes == "leaves":
nodes = list(all_nodes.intersection(leaves))
nodes = [node for node in all_nodes if node[1] in attrs["leaves"]]
elif nodes == "internal":
nodes = list(all_nodes.difference(leaves))
nodes = [node for node in all_nodes if node[1] not in attrs["leaves"]]
elif isinstance(nodes, Sequence):
if len(attrs["tree_keys"]) > 1 and len(tree_keys) > 1:
raise ValueError("Multiple trees are present. To plot a list of nodes, you must specify the tree.")
nodes = [f"{tree_keys[0]}-{node}" for node in nodes]
nodes = [(tree_keys[0], node) for node in nodes]
if set(nodes).issubset(all_nodes):
nodes = list(nodes)
else:
Expand All @@ -281,7 +276,6 @@ def nodes(
color_data = get_keyed_node_data(tdata, color, tree_keys)[color]
if len(color_data) == 0:
raise ValueError(f"Key {color!r} is not present in any node.")
color_data.index = color_data.index.map("-".join)
if color_data.dtype.kind in ["i", "f"]:
if not vmin:
vmin = color_data.min()
Expand Down Expand Up @@ -311,7 +305,6 @@ def nodes(
size_data = get_keyed_node_data(tdata, size, tree_keys)[size]
if len(size_data) == 0:
raise ValueError(f"Key {size!r} is not present in any node.")
size_data.index = size_data.index.map("-".join)
sizes = [size_data[node] if node in size_data.index else na_size for node in nodes]
kwargs.update({"s": sizes})
else:
Expand All @@ -323,7 +316,6 @@ def nodes(
style_data = get_keyed_node_data(tdata, style, tree_keys)[style]
if len(style_data) == 0:
raise ValueError(f"Key {style!r} is not present in any node.")
style_data.index = style_data.index.map("-".join)
mmap = _get_categorical_markers(tdata, style, style_data, markers)
styles = [mmap[style_data[node]] if node in style_data.index else na_style for node in nodes]
for style in set(styles):
Expand Down Expand Up @@ -417,7 +409,6 @@ def annotation(
leaves = attrs["leaves"]
# Get data
data, is_array = get_keyed_obs_data(tdata, keys)
data = data.loc[leaves]
numeric_data = data.select_dtypes(exclude="category")
if len(numeric_data) > 0 and not vmin:
vmin = numeric_data.min().min()
Expand Down Expand Up @@ -448,7 +439,7 @@ def annotation(
end_lat = start_lat + attrs["depth"] + 2 * np.pi
lats = np.linspace(start_lat, end_lat, data.shape[1] + 1)
for col in data.columns:
rgb_array.append(_series_to_rgb_array(data.loc[col], cmap, vmin=vmin, vmax=vmax, na_color=na_color))
rgb_array.append(_series_to_rgb_array(data.loc[leaves, col], cmap, vmin=vmin, vmax=vmax, na_color=na_color))
else:
for key in keys:
if data[key].dtype == "category":
Expand Down
31 changes: 22 additions & 9 deletions src/pycea/pp/setup_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from collections.abc import Sequence

import networkx as nx
import pandas as pd
import treedata as td

from pycea.utils import get_keyed_node_data, get_root, get_trees
from pycea.utils import get_keyed_leaf_data, get_keyed_node_data, get_root, get_trees


def _add_depth(tree, depth_key):
Expand All @@ -16,24 +17,36 @@ def _add_depth(tree, depth_key):


def add_depth(
tdata: td.TreeData, depth_key: str = "depth", tree: str | Sequence[str] | None = None, copy: bool = False
):
"""Adds a depth attribute to the nodes of a tree.
tdata: td.TreeData, key_added: str = "depth", tree: str | Sequence[str] | None = None, copy: bool = False
) -> None | pd.DataFrame:
"""Adds a depth attribute to the tree.
Parameters
----------
tdata
TreeData object.
depth_key
Node attribute key to store the depth.
key_added
Key to store node depths.
tree
The `obst` key or keys of the trees to use. If `None`, all trees are used.
copy
If True, returns a pd.DataFrame node depths.
If True, returns a :class:`DataFrame <pandas.DataFrame>` with node depths.
Returns
-------
Returns `None` if `copy=False`, else returns node depths.
Sets the following fields:
* `tdata.obs[key_added]` : :class:`Series <pandas.Series>` (dtype `float`)
- Distance from the root node.
* `tdata.obst[tree].nodes[key_added]` : `float`
- Distance from the root node.
"""
tree_keys = tree
trees = get_trees(tdata, tree_keys)
for _, tree in trees.items():
_add_depth(tree, depth_key)
_add_depth(tree, key_added)
tdata.obs[key_added] = get_keyed_leaf_data(tdata, key_added)[key_added]
if copy:
return get_keyed_node_data(tdata, depth_key)
return get_keyed_node_data(tdata, key_added, tree_keys)
5 changes: 4 additions & 1 deletion src/pycea/tl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from .ancestral_states import ancestral_states
from .clades import clades
from .distance import compare_distance, distance
from .sort import sort
from .ancestral_states import ancestral_states
from .tree_distance import tree_distance
from .tree_neighbors import tree_neighbors
63 changes: 63 additions & 0 deletions src/pycea/tl/_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from collections.abc import Callable
from typing import Literal

import numpy as np
import treedata as td

# Signature of a pairwise metric callable: two numpy arrays in, one float out.
_MetricFn = Callable[[np.ndarray, np.ndarray], float]

# Closed set of metric names, expressed as a Literal type for use in annotations.
# NOTE(review): these look like scipy.spatial.distance / scikit-learn pairwise
# metric names — confirm each one (e.g. "kulsinski") is still accepted by the
# library versions this package supports.
_Metric = Literal[
"braycurtis",
"canberra",
"chebyshev",
"cityblock",
"cosine",
"correlation",
"dice",
"euclidean",
"hamming",
"jaccard",
"kulsinski",
"l1",
"l2",
"mahalanobis",
"minkowski",
"manhattan",
"rogerstanimoto",
"russellrao",
"seuclidean",
"sokalmichener",
"sokalsneath",
"sqeuclidean",
"yule",
]


def _lca_distance(tree, depth_key, node1, node2, lca):
    """Return the lca distance between ``node1`` and ``node2``.

    The value is read from the ``depth_key`` attribute stored in
    ``tree.nodes``: for two different nodes it is the attribute of ``lca``;
    when both arguments are the same node it is that node's own attribute.
    """
    # A node compared with itself is looked up directly instead of via `lca`.
    anchor = node1 if node1 == node2 else lca
    return tree.nodes[anchor][depth_key]


def _path_distance(tree, depth_key, node1, node2, lca):
    """Return the path distance between ``node1`` and ``node2``.

    For two different nodes this is the absolute value of
    ``depth(node1) + depth(node2) - 2 * depth(lca)``, where each depth is the
    ``depth_key`` attribute stored in ``tree.nodes``. Identical nodes are at
    distance ``0``.
    """
    # Same node: no lookups are performed at all.
    if node1 == node2:
        return 0
    attrs = tree.nodes
    first_depth = attrs[node1][depth_key]
    second_depth = attrs[node2][depth_key]
    lca_depth = attrs[lca][depth_key]
    return abs(first_depth + second_depth - 2 * lca_depth)


# Declared signature of the tree metric callables returned by `_get_tree_metric`.
# NOTE(review): `_lca_distance` and `_path_distance` in this file index
# `tree.nodes[...]` and return a single depth-derived value, so the first
# parameter and the return type declared here may not match them — confirm
# whether this should describe an individual tree and a scalar result.
_TreeMetricFn = Callable[[td.TreeData, str, str, str, str], np.ndarray]

# Names of the tree metrics that `_get_tree_metric` recognises.
_TreeMetric = Literal["lca", "path"]


def _get_tree_metric(metric: str) -> _TreeMetricFn:
    """Look up the tree metric function registered under ``metric``.

    ``"lca"`` maps to ``_lca_distance`` and ``"path"`` maps to
    ``_path_distance``. Any other value raises ``ValueError``.
    """
    if metric == "lca":
        return _lca_distance
    if metric == "path":
        return _path_distance
    raise ValueError(f"Unknown metric: {metric}. Valid metrics are 'lca' and 'path'.")
Loading

0 comments on commit 995f12a

Please sign in to comment.