Skip to content

Commit

Permalink
Merge pull request #4 from YosefLab/clades
Browse files Browse the repository at this point in the history
Clades
  • Loading branch information
colganwi authored May 24, 2024
2 parents 992b3a8 + 0302d9a commit e89beb6
Show file tree
Hide file tree
Showing 13 changed files with 563 additions and 120 deletions.
11 changes: 11 additions & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,17 @@

## Tools

```{eval-rst}
.. module:: pycea.tl
.. currentmodule:: pycea
.. autosummary::
:toctree: generated
tl.clades
tl.sort
```

## Plotting

```{eval-rst}
Expand Down
119 changes: 80 additions & 39 deletions src/pycea/pl/_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Plotting utilities"""
from __future__ import annotations

import collections.abc as cabc
import warnings
from collections.abc import Mapping, Sequence

import cycler
import matplotlib as mpl
Expand All @@ -11,22 +12,24 @@
import numpy as np
from scanpy.plotting import palettes

from pycea.utils import get_root
from pycea.utils import get_leaves, get_root


def layout_tree(
def layout_nodes_and_branches(
tree: nx.DiGraph,
depth_key: str = "time",
leaf_coords: Mapping[str],
depth_key: str = "depth",
polar: bool = False,
extend_branches: bool = True,
angled_branches: bool = False,
):
"""Given a tree, computes the coordinates of the nodes and branches.
"""Given a tree and leaf coordinates, computes the coordinates of the nodes and branches.
Parameters
----------
tree
The `nx.DiGraph` representing the tree.
leaf_coords
A dictionary mapping leaves to their coordinates.
depth_key
The node attribute to use as the depth of the nodes.
polar
Expand All @@ -42,38 +45,15 @@ def layout_tree(
A dictionary mapping nodes to their coordinates.
branch_coords
A dictionary mapping edges to their coordinates.
leaves
A list of the leaves of the tree.
max_depth
The maximum depth of the tree.
"""
# Get node depths
n_leaves = 0
root = get_root(tree)
depths = {}
for node in tree.nodes():
if tree.out_degree(node) == 0:
n_leaves += 1
depths[node] = tree.nodes[node].get(depth_key)
max_depth = max(depths.values())
# Get node coordinates
i = 0
leaves = []
node_coords = {}
for node in nx.dfs_postorder_nodes(tree, root):
if tree.out_degree(node) == 0:
lon = (i / (n_leaves)) * 2 * np.pi # + 2 * np.pi / n_leaves
if extend_branches:
node_coords[node] = (max_depth, lon)
else:
node_coords[node] = (depths[node], lon)
leaves.append(node)
i += 1
else:
node_coords = leaf_coords.copy()
for node in nx.dfs_postorder_nodes(tree, get_root(tree)):
if tree.out_degree(node) != 0:
children = list(tree.successors(node))
min_lon = min(node_coords[child][1] for child in children)
max_lon = max(node_coords[child][1] for child in children)
node_coords[node] = (depths[node], (min_lon + max_lon) / 2)
node_coords[node] = (tree.nodes[node].get(depth_key), (min_lon + max_lon) / 2)
# Get branch coordinates
branch_coords = {}
for parent, child in tree.edges():
Expand All @@ -96,9 +76,69 @@ def layout_tree(
inter_lons = np.linspace(lons[0], lons[1], int(np.ceil(angle / min_angle)))
inter_lats = [lats[0]] * len(inter_lons)
branch_coords[(parent, child)] = (np.append(inter_lats, lats[-1]), np.append(inter_lons, lons[-1]))
return node_coords, branch_coords


def layout_trees(
trees: Mapping[str],
depth_key: str = "depth",
polar: bool = False,
extend_branches: bool = True,
angled_branches: bool = False,
):
"""Given a list of trees, computes the coordinates of the nodes and branches.
Parameters
----------
trees
A dictionary mapping tree names to `nx.DiGraph` representing the trees.
depth_key
The node attribute to use as the depth of the nodes.
polar
Whether to plot the tree in polar coordinates.
extend_branches
Whether to extend branches so the tips are at the same depth.
angled_branches
Whether to plot branches at an angle.
Returns
-------
node_coords
A dictionary mapping nodes to their coordinates.
branch_coords
A dictionary mapping edges to their coordinates.
leaves
A list of the leaves of the tree.
max_depth
The maximum depth of the tree.
"""
# Get leaf coordinates
leaves = []
depths = []
for _ , tree in trees.items():
tree_leaves = get_leaves(tree)
leaves.extend(tree_leaves)
depths.extend(tree.nodes[leaf].get(depth_key) for leaf in tree_leaves)
max_depth = max(depths)
n_leaves = len(leaves)
leaf_coords = {}
for i in range(n_leaves):
lon = (i / n_leaves) * 2 * np.pi
if extend_branches:
leaf_coords[leaves[i]] = (max_depth, lon)
else:
leaf_coords[leaves[i]] = (depths[i], lon)
# Layout trees
node_coords = {}
branch_coords = {}
for key, tree in trees.items():
tree_node_coords,tree_branch_coords = layout_nodes_and_branches(tree, leaf_coords, depth_key, polar, angled_branches)
node_coords.update({f"{key}-{node}": coords for node, coords in tree_node_coords.items()})
branch_coords.update({(f"{key}-{parent}", f"{key}-{child}"): coords for (parent, child), coords in tree_branch_coords.items()})
return node_coords, branch_coords, leaves, max_depth



def _get_default_categorical_colors(length):
"""Get default categorical colors for plotting."""
# check if default matplotlib palette has enough colors
Expand Down Expand Up @@ -133,20 +173,20 @@ def _get_categorical_colors(tdata, key, data, palette=None):
# Use default colors if no palette is provided
if palette is None:
colors_list = tdata.uns.get(key + "_colors", None)
if colors_list is None or len(colors_list) > len(categories):
if (colors_list is None) or (len(colors_list) < len(categories)):
colors_list = _get_default_categorical_colors(len(categories))
# Use provided palette
else:
if isinstance(palette, str) and palette in plt.colormaps():
# this creates a palette from a colormap. E.g. 'Accent, Dark2, tab20'
cmap = plt.get_cmap(palette)
colors_list = [mcolors.to_hex(x, keep_alpha=True) for x in cmap(np.linspace(0, 1, len(categories)))]
elif isinstance(palette, cabc.Mapping):
elif isinstance(palette, Mapping):
colors_list = [mcolors.to_hex(palette[k], keep_alpha=True) for k in categories]
else:
# check if palette is a list and convert it to a cycler, thus
# it doesnt matter if the list is shorter than the categories length:
if isinstance(palette, cabc.Sequence):
if isinstance(palette, Sequence):
if len(palette) < len(categories):
warnings.warn(
"Length of palette colors is smaller than the number of "
Expand All @@ -173,7 +213,8 @@ def _get_categorical_colors(tdata, key, data, palette=None):
cc = palette()
colors_list = [mcolors.to_hex(next(cc)["color"], keep_alpha=True) for x in range(len(categories))]
# store colors in tdata
tdata.uns[key + "_colors"] = colors_list
if len(categories) <= len(palettes.default_102):
tdata.uns[key + "_colors"] = colors_list
return dict(zip(categories, colors_list))


Expand All @@ -191,10 +232,10 @@ def _get_categorical_markers(tdata, key, data, markers=None):
markers_list = default_markers[: len(categories)]
# Use provided markers
else:
if isinstance(markers, cabc.Mapping):
if isinstance(markers, Mapping):
markers_list = [markers[k] for k in categories]
else:
if not isinstance(markers, cabc.Sequence):
if not isinstance(markers, Sequence):
raise ValueError("Please check that the value of 'markers' is a valid " "list of marker names.")
if len(markers) < len(categories):
warnings.warn(
Expand Down
Loading

0 comments on commit e89beb6

Please sign in to comment.