Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add depth #6

Merged
merged 1 commit into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,16 @@

## Preprocessing

```{eval-rst}
.. module:: pycea.pp
.. currentmodule:: pycea

.. autosummary::
:toctree: generated

pp.add_depth
```

## Tools

```{eval-rst}
Expand Down
2 changes: 1 addition & 1 deletion src/pycea/pl/plot_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def annotation(
**kwargs,
) -> Axes:
"""\
Plot leaf annotations.
Plot leaf annotations for a tree.

Parameters
----------
Expand Down
1 change: 1 addition & 0 deletions src/pycea/pp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .setup_tree import add_depth
39 changes: 39 additions & 0 deletions src/pycea/pp/setup_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from __future__ import annotations

from collections.abc import Sequence

import networkx as nx
import treedata as td

from pycea.utils import get_keyed_node_data, get_root, get_trees


def _add_depth(tree, depth_key):
"""Adds a depth attribute to the nodes of a tree."""
root = get_root(tree)
depths = nx.single_source_shortest_path_length(tree, root)
nx.set_node_attributes(tree, depths, depth_key)


def add_depth(
tdata: td.TreeData, depth_key: str = "depth", tree: str | Sequence[str] | None = None, copy: bool = False
):
"""Adds a depth attribute to the nodes of a tree.

Parameters
----------
tdata
TreeData object.
depth_key
Node attribute key to store the depth.
tree
The `obst` key or keys of the trees to use. If `None`, all trees are used.
copy
If True, returns a pd.DataFrame node depths.
"""
tree_keys = tree
trees = get_trees(tdata, tree_keys)
for _, tree in trees.items():
_add_depth(tree, depth_key)
if copy:
return get_keyed_node_data(tdata, depth_key)
4 changes: 2 additions & 2 deletions src/pycea/tl/clades.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _clade_name_generator():


def _clades(tree, depth, depth_key, clades, clade_key, name_generator):
"""Identifies clades in a tree."""
"""Marks clades in a tree."""
if (depth is not None) and (clades is None):
nodes = _nodes_at_depth(tree, get_root(tree), [], depth, depth_key)
clades = dict(zip(nodes, name_generator))
Expand Down Expand Up @@ -61,7 +61,7 @@ def clades(
tree: str | Sequence[str] | None = None,
copy: bool = False,
) -> None | Mapping:
"""Identifies clades in a tree.
"""Marks clades in a tree.

Parameters
----------
Expand Down
7 changes: 5 additions & 2 deletions src/pycea/tl/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@ def _sort_tree(tree, key, reverse=False):
try:
sorted_children = sorted(tree.successors(node), key=lambda x: tree.nodes[x][key], reverse=reverse)
except KeyError as err:
raise KeyError(f"Node {next(tree.successors(node))} does not have a {key} attribute.") from err
raise KeyError(
f"Node {next(tree.successors(node))} does not have a {key} attribute.",
"You may need to call `ancestral_states` to infer internal node values",
) from err
tree.remove_edges_from([(node, child) for child in tree.successors(node)])
tree.add_edges_from([(node, child) for child in sorted_children])
return tree


def sort(tdata: td.TreeData, key: str, reverse: bool = False, tree: str | Sequence[str] | None = None) -> None:
"""Reorders branches based on a given key.
"""Reorders branches based on a node attribute.

Parameters
----------
Expand Down
26 changes: 26 additions & 0 deletions tests/test_setup_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import networkx as nx
import pandas as pd
import pytest
import treedata as td

from pycea.pp.setup_tree import add_depth


@pytest.fixture
def tdata():
tree1 = nx.DiGraph([("root", "A"), ("root", "B"), ("B", "C"), ("B", "D")])
tree2 = nx.DiGraph([("root", "E"), ("root", "F")])
tdata = td.TreeData(obs=pd.DataFrame(index=["A", "C", "D", "E", "F"]), obst={"tree1": tree1, "tree2": tree2})
yield tdata


def test_add_depth(tdata):
depths = add_depth(tdata, depth_key="depth", copy=True)
assert depths.loc[("tree1", "root"), "depth"] == 0
assert depths.loc[("tree1", "C"), "depth"] == 2
assert tdata.obst["tree1"].nodes["root"]["depth"] == 0
assert tdata.obst["tree1"].nodes["C"]["depth"] == 2


if __name__ == "__main__":
pytest.main(["-v", __file__])
Loading