Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Plotting update #10

Merged
merged 5 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,5 @@ python:
path: .
extra_requirements:
- doc
formats:
- pdf
7 changes: 3 additions & 4 deletions src/pycea/pl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import numpy as np
from scanpy.plotting import palettes

from pycea.utils import get_leaves, get_root
from pycea.utils import check_tree_has_key, get_leaves, get_root


def layout_nodes_and_branches(
Expand Down Expand Up @@ -117,13 +117,12 @@ def layout_trees(
leaves = []
depths = []
for _, tree in trees.items():
check_tree_has_key(tree, depth_key)
tree_leaves = get_leaves(tree)
leaves.extend(tree_leaves)
depths.extend(tree.nodes[leaf].get(depth_key) for leaf in tree_leaves)
if len(depths) != len(leaves):
raise ValueError(
f"Tree does not have {depth_key} attribute. You can run `pycea.pp.add_depth` to add depth attribute."
)
raise ValueError(f"Every node in the tree must have a {depth_key} attribute. ")
max_depth = max(depths)
n_leaves = len(leaves)
leaf_coords = {}
Expand Down
12 changes: 7 additions & 5 deletions src/pycea/pl/plot_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ def branches(
""" # noqa: D205
# Setup
tree_keys = tree
if not ax:
ax = plt.gca()
if (ax.name == "polar" and not polar) or (ax.name != "polar" and polar):
if ax is None:
fig, ax = plt.subplots(subplot_kw={"projection": "polar"} if polar else None)
elif (ax.name == "polar" and not polar) or (ax.name != "polar" and polar):
warnings.warn("Polar setting of axes does not match requested type. Creating new axes.", stacklevel=2)
fig, ax = plt.subplots(subplot_kw={"projection": "polar"} if polar else None)
kwargs = kwargs if kwargs else {}
Expand Down Expand Up @@ -492,7 +492,7 @@ def tree(
tdata: td.TreeData,
keys: str | Sequence[str] = None,
tree: str | Sequence[str] | None = None,
nodes: str | Sequence[str] = None,
nodes: str | Sequence[str] | None = None,
polar: bool = False,
extend_branches: bool = False,
angled_branches: bool = False,
Expand Down Expand Up @@ -520,7 +520,7 @@ def tree(
tree
The `obst` key or keys of the trees to plot. If `None`, all trees are plotted.
nodes
Either "all", "leaves", "internal", or a list of nodes to plot.
Either "all", "leaves", "internal", or a list of nodes to plot. Defaults to "internal" if node color, style, or size is set.
polar
Whether to plot the tree in polar coordinates.
extend_branches
Expand Down Expand Up @@ -562,6 +562,8 @@ def tree(
ax=ax,
)
# Plot nodes
if nodes is None and (node_color != "black" or node_style != "o" or node_size != 10):
nodes = "internal"
if nodes:
ax = _nodes(
tdata,
Expand Down
8 changes: 0 additions & 8 deletions src/pycea/tl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,6 @@ def _format_as_list(obj):
return obj


def _check_tree_overlap(tdata, tree_keys):
"""If overlap is allowed there can only be one tree"""
n_trees = len(tdata.obst.keys())
if (n_trees > 1) and tdata.allow_overlap and len(tree_keys) != 1:
raise ValueError("Must specify a singe tree if tdata.allow_overlap is True.")
return


def _set_distances_and_connectivities(tdata, key_added, dist, connect, update):
"""Set distances and connectivities in tdata"""
dist_key = f"{key_added}_distances"
Expand Down
2 changes: 0 additions & 2 deletions src/pycea/tl/tree_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from ._metrics import _get_tree_metric, _TreeMetric
from ._utils import (
_check_previous_params,
_check_tree_overlap,
_csr_data_mask,
_format_keys,
_set_distances_and_connectivities,
Expand Down Expand Up @@ -168,7 +167,6 @@ def tree_distance(
trees = get_trees(tdata, tree_keys)
metric_fn = _get_tree_metric(metric)
single_obs = False
_check_tree_overlap(tdata, tree_keys)
if update:
_check_previous_params(tdata, {"metric": metric}, key_added, ["neighbors", "distances"])
# Get set of pairs for each tree
Expand Down
6 changes: 2 additions & 4 deletions src/pycea/tl/tree_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from ._utils import (
_assert_param_xor,
_check_previous_params,
_check_tree_overlap,
_csr_data_mask,
_set_distances_and_connectivities,
_set_random_state,
Expand Down Expand Up @@ -50,7 +49,7 @@ def _bfs_by_distance(tree, start_node, n_neighbors, max_dist, metric, depth_key)
if len(neighbors) >= n_neighbors:
break
heapq.heappush(queue, (child_distance, child))
visited.add(child)
visited.add(child)
# Add parents to queue
for parent in nx.ancestors(tree, node):
if parent not in visited:
Expand All @@ -60,7 +59,7 @@ def _bfs_by_distance(tree, start_node, n_neighbors, max_dist, metric, depth_key)
parent_distance = tree.nodes[parent][depth_key]
if parent_distance <= max_dist:
heapq.heappush(queue, (parent_distance, parent))
visited.add(parent)
visited.add(parent)
return neighbors, neighbor_distances


Expand Down Expand Up @@ -143,7 +142,6 @@ def tree_neighbors(
_assert_param_xor({"n_neighbors": n_neighbors, "max_dist": max_dist})
_ = _get_tree_metric(metric)
tree_keys = tree
_check_tree_overlap(tdata, tree_keys)
if update:
_check_previous_params(tdata, {"metric": metric}, key_added, ["neighbors", "distances"])
# Neighbors of a single leaf
Expand Down
2 changes: 1 addition & 1 deletion src/pycea/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def check_tree_has_key(tree: nx.DiGraph, key: str):
sampled_nodes = random.sample(list(tree.nodes), min(10, len(tree.nodes)))
for node in sampled_nodes:
if key not in tree.nodes[node]:
message = f"Tree does not have {key} attribute."
message = f"Tree nodes to not have {key} attribute."
if key == "depth":
message += " You can run `pycea.pp.add_depth` to add depth attribute."
raise ValueError(message)
Expand Down
Loading