Skip to content

Commit

Permalink
Merge pull request #10 from YosefLab/plotting-update
Browse files Browse the repository at this point in the history
Plotting update
  • Loading branch information
colganwi authored Aug 23, 2024
2 parents 9687424 + ae4bc66 commit 6c5e760
Show file tree
Hide file tree
Showing 7 changed files with 15 additions and 24 deletions.
2 changes: 2 additions & 0 deletions .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,5 @@ python:
path: .
extra_requirements:
- doc
formats:
- pdf
7 changes: 3 additions & 4 deletions src/pycea/pl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import numpy as np
from scanpy.plotting import palettes

from pycea.utils import get_leaves, get_root
from pycea.utils import check_tree_has_key, get_leaves, get_root


def layout_nodes_and_branches(
Expand Down Expand Up @@ -117,13 +117,12 @@ def layout_trees(
leaves = []
depths = []
for _, tree in trees.items():
check_tree_has_key(tree, depth_key)
tree_leaves = get_leaves(tree)
leaves.extend(tree_leaves)
depths.extend(tree.nodes[leaf].get(depth_key) for leaf in tree_leaves)
if len(depths) != len(leaves):
raise ValueError(
f"Tree does not have {depth_key} attribute. You can run `pycea.pp.add_depth` to add depth attribute."
)
raise ValueError(f"Every node in the tree must have a {depth_key} attribute. ")
max_depth = max(depths)
n_leaves = len(leaves)
leaf_coords = {}
Expand Down
12 changes: 7 additions & 5 deletions src/pycea/pl/plot_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ def branches(
""" # noqa: D205
# Setup
tree_keys = tree
if not ax:
ax = plt.gca()
if (ax.name == "polar" and not polar) or (ax.name != "polar" and polar):
if ax is None:
fig, ax = plt.subplots(subplot_kw={"projection": "polar"} if polar else None)
elif (ax.name == "polar" and not polar) or (ax.name != "polar" and polar):
warnings.warn("Polar setting of axes does not match requested type. Creating new axes.", stacklevel=2)
fig, ax = plt.subplots(subplot_kw={"projection": "polar"} if polar else None)
kwargs = kwargs if kwargs else {}
Expand Down Expand Up @@ -492,7 +492,7 @@ def tree(
tdata: td.TreeData,
keys: str | Sequence[str] = None,
tree: str | Sequence[str] | None = None,
nodes: str | Sequence[str] = None,
nodes: str | Sequence[str] | None = None,
polar: bool = False,
extend_branches: bool = False,
angled_branches: bool = False,
Expand Down Expand Up @@ -520,7 +520,7 @@ def tree(
tree
The `obst` key or keys of the trees to plot. If `None`, all trees are plotted.
nodes
Either "all", "leaves", "internal", or a list of nodes to plot.
Either "all", "leaves", "internal", or a list of nodes to plot. Defaults to "internal" if node color, style, or size is set.
polar
Whether to plot the tree in polar coordinates.
extend_branches
Expand Down Expand Up @@ -562,6 +562,8 @@ def tree(
ax=ax,
)
# Plot nodes
if nodes is None and (node_color != "black" or node_style != "o" or node_size != 10):
nodes = "internal"
if nodes:
ax = _nodes(
tdata,
Expand Down
8 changes: 0 additions & 8 deletions src/pycea/tl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,6 @@ def _format_as_list(obj):
return obj


def _check_tree_overlap(tdata, tree_keys):
"""If overlap is allowed there can only be one tree"""
n_trees = len(tdata.obst.keys())
if (n_trees > 1) and tdata.allow_overlap and len(tree_keys) != 1:
raise ValueError("Must specify a singe tree if tdata.allow_overlap is True.")
return


def _set_distances_and_connectivities(tdata, key_added, dist, connect, update):
"""Set distances and connectivities in tdata"""
dist_key = f"{key_added}_distances"
Expand Down
2 changes: 0 additions & 2 deletions src/pycea/tl/tree_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from ._metrics import _get_tree_metric, _TreeMetric
from ._utils import (
_check_previous_params,
_check_tree_overlap,
_csr_data_mask,
_format_keys,
_set_distances_and_connectivities,
Expand Down Expand Up @@ -168,7 +167,6 @@ def tree_distance(
trees = get_trees(tdata, tree_keys)
metric_fn = _get_tree_metric(metric)
single_obs = False
_check_tree_overlap(tdata, tree_keys)
if update:
_check_previous_params(tdata, {"metric": metric}, key_added, ["neighbors", "distances"])
# Get set of pairs for each tree
Expand Down
6 changes: 2 additions & 4 deletions src/pycea/tl/tree_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from ._utils import (
_assert_param_xor,
_check_previous_params,
_check_tree_overlap,
_csr_data_mask,
_set_distances_and_connectivities,
_set_random_state,
Expand Down Expand Up @@ -50,7 +49,7 @@ def _bfs_by_distance(tree, start_node, n_neighbors, max_dist, metric, depth_key)
if len(neighbors) >= n_neighbors:
break
heapq.heappush(queue, (child_distance, child))
visited.add(child)
visited.add(child)
# Add parents to queue
for parent in nx.ancestors(tree, node):
if parent not in visited:
Expand All @@ -60,7 +59,7 @@ def _bfs_by_distance(tree, start_node, n_neighbors, max_dist, metric, depth_key)
parent_distance = tree.nodes[parent][depth_key]
if parent_distance <= max_dist:
heapq.heappush(queue, (parent_distance, parent))
visited.add(parent)
visited.add(parent)
return neighbors, neighbor_distances


Expand Down Expand Up @@ -143,7 +142,6 @@ def tree_neighbors(
_assert_param_xor({"n_neighbors": n_neighbors, "max_dist": max_dist})
_ = _get_tree_metric(metric)
tree_keys = tree
_check_tree_overlap(tdata, tree_keys)
if update:
_check_previous_params(tdata, {"metric": metric}, key_added, ["neighbors", "distances"])
# Neighbors of a single leaf
Expand Down
2 changes: 1 addition & 1 deletion src/pycea/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def check_tree_has_key(tree: nx.DiGraph, key: str):
sampled_nodes = random.sample(list(tree.nodes), min(10, len(tree.nodes)))
for node in sampled_nodes:
if key not in tree.nodes[node]:
message = f"Tree does not have {key} attribute."
message = f"Tree nodes to not have {key} attribute."
if key == "depth":
message += " You can run `pycea.pp.add_depth` to add depth attribute."
raise ValueError(message)
Expand Down

0 comments on commit 6c5e760

Please sign in to comment.