diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 23a5340..0cdb04b 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -14,3 +14,5 @@ python: path: . extra_requirements: - doc +formats: + - pdf \ No newline at end of file diff --git a/src/pycea/pl/_utils.py b/src/pycea/pl/_utils.py index 5fd5a2b..55457b8 100755 --- a/src/pycea/pl/_utils.py +++ b/src/pycea/pl/_utils.py @@ -13,7 +13,7 @@ import numpy as np from scanpy.plotting import palettes -from pycea.utils import get_leaves, get_root +from pycea.utils import check_tree_has_key, get_leaves, get_root def layout_nodes_and_branches( @@ -117,13 +117,12 @@ def layout_trees( leaves = [] depths = [] for _, tree in trees.items(): + check_tree_has_key(tree, depth_key) tree_leaves = get_leaves(tree) leaves.extend(tree_leaves) depths.extend(tree.nodes[leaf].get(depth_key) for leaf in tree_leaves) if len(depths) != len(leaves): - raise ValueError( - f"Tree does not have {depth_key} attribute. You can run `pycea.pp.add_depth` to add depth attribute." - ) + raise ValueError(f"Every node in the tree must have a {depth_key} attribute. ") max_depth = max(depths) n_leaves = len(leaves) leaf_coords = {} diff --git a/src/pycea/pl/plot_tree.py b/src/pycea/pl/plot_tree.py index ef0b7b4..1d82dd6 100644 --- a/src/pycea/pl/plot_tree.py +++ b/src/pycea/pl/plot_tree.py @@ -79,9 +79,9 @@ def branches( """ # noqa: D205 # Setup tree_keys = tree - if not ax: - ax = plt.gca() - if (ax.name == "polar" and not polar) or (ax.name != "polar" and polar): + if ax is None: + fig, ax = plt.subplots(subplot_kw={"projection": "polar"} if polar else None) + elif (ax.name == "polar" and not polar) or (ax.name != "polar" and polar): warnings.warn("Polar setting of axes does not match requested type. Creating new axes.", stacklevel=2) fig, ax = plt.subplots(subplot_kw={"projection": "polar"} if polar else None) kwargs = kwargs if kwargs else {} @@ -492,7 +492,7 @@ def tree( tdata: td.TreeData, keys: str | Sequence[str] = None, tree: str | Sequence[str] | None = None, - nodes: str | Sequence[str] = None, + nodes: str | Sequence[str] | None = None, polar: bool = False, extend_branches: bool = False, angled_branches: bool = False, @@ -520,7 +520,7 @@ def tree( tree The `obst` key or keys of the trees to plot. If `None`, all trees are plotted. nodes - Either "all", "leaves", "internal", or a list of nodes to plot. + Either "all", "leaves", "internal", or a list of nodes to plot. Defaults to "internal" if node color, style, or size is set. polar Whether to plot the tree in polar coordinates. extend_branches @@ -562,6 +562,8 @@ def tree( ax=ax, ) # Plot nodes + if nodes is None and (node_color != "black" or node_style != "o" or node_size != 10): + nodes = "internal" if nodes: ax = _nodes( tdata, diff --git a/src/pycea/tl/_utils.py b/src/pycea/tl/_utils.py index 37d6752..be1c6bc 100755 --- a/src/pycea/tl/_utils.py +++ b/src/pycea/tl/_utils.py @@ -62,14 +62,6 @@ def _format_as_list(obj): return obj -def _check_tree_overlap(tdata, tree_keys): - """If overlap is allowed there can only be one tree""" - n_trees = len(tdata.obst.keys()) - if (n_trees > 1) and tdata.allow_overlap and len(tree_keys) != 1: - raise ValueError("Must specify a singe tree if tdata.allow_overlap is True.") - return - - def _set_distances_and_connectivities(tdata, key_added, dist, connect, update): """Set distances and connectivities in tdata""" dist_key = f"{key_added}_distances" diff --git a/src/pycea/tl/tree_distance.py b/src/pycea/tl/tree_distance.py index aaced70..a27bff0 100755 --- a/src/pycea/tl/tree_distance.py +++ b/src/pycea/tl/tree_distance.py @@ -15,7 +15,6 @@ from ._metrics import _get_tree_metric, _TreeMetric from ._utils import ( _check_previous_params, - _check_tree_overlap, _csr_data_mask, _format_keys, _set_distances_and_connectivities, @@ -168,7 +167,6 @@ def tree_distance( trees = get_trees(tdata, tree_keys) metric_fn = _get_tree_metric(metric) single_obs = False - _check_tree_overlap(tdata, tree_keys) if update: _check_previous_params(tdata, {"metric": metric}, key_added, ["neighbors", "distances"]) # Get set of pairs for each tree diff --git a/src/pycea/tl/tree_neighbors.py b/src/pycea/tl/tree_neighbors.py index 5cb205d..f26a8dc 100755 --- a/src/pycea/tl/tree_neighbors.py +++ b/src/pycea/tl/tree_neighbors.py @@ -14,7 +14,6 @@ from ._utils import ( _assert_param_xor, _check_previous_params, - _check_tree_overlap, _csr_data_mask, _set_distances_and_connectivities, _set_random_state, @@ -50,7 +49,7 @@ def _bfs_by_distance(tree, start_node, n_neighbors, max_dist, metric, depth_key) if len(neighbors) >= n_neighbors: break heapq.heappush(queue, (child_distance, child)) - visited.add(child) + visited.add(child) # Add parents to queue for parent in nx.ancestors(tree, node): if parent not in visited: @@ -60,7 +59,7 @@ def _bfs_by_distance(tree, start_node, n_neighbors, max_dist, metric, depth_key) parent_distance = tree.nodes[parent][depth_key] if parent_distance <= max_dist: heapq.heappush(queue, (parent_distance, parent)) - visited.add(parent) + visited.add(parent) return neighbors, neighbor_distances @@ -143,7 +142,6 @@ def tree_neighbors( _assert_param_xor({"n_neighbors": n_neighbors, "max_dist": max_dist}) _ = _get_tree_metric(metric) tree_keys = tree - _check_tree_overlap(tdata, tree_keys) if update: _check_previous_params(tdata, {"metric": metric}, key_added, ["neighbors", "distances"]) # Neighbors of a single leaf diff --git a/src/pycea/utils.py b/src/pycea/utils.py index 551b574..f6beafc 100755 --- a/src/pycea/utils.py +++ b/src/pycea/utils.py @@ -37,7 +37,7 @@ def check_tree_has_key(tree: nx.DiGraph, key: str): sampled_nodes = random.sample(list(tree.nodes), min(10, len(tree.nodes))) for node in sampled_nodes: if key not in tree.nodes[node]: - message = f"Tree does not have {key} attribute." + message = f"Tree nodes to not have {key} attribute." if key == "depth": message += " You can run `pycea.pp.add_depth` to add depth attribute." raise ValueError(message)